pax_global_header00006660000000000000000000000064147600363540014522gustar00rootroot0000000000000052 comment=8076c75051f943c29bae79a2a32d61eda48e582e workflow-0.11.8/000077500000000000000000000000001476003635400134635ustar00rootroot00000000000000workflow-0.11.8/.editorconfig000066400000000000000000000003161476003635400161400ustar00rootroot00000000000000# top-most EditorConfig file root = true # all files [*] indent_style = tab indent_size = 4 [src/kernel/list.h] indent_style = tab indent_size = 8 [src/kernel/rbtree.*] indent_style = tab indent_size = 8 workflow-0.11.8/.github/000077500000000000000000000000001476003635400150235ustar00rootroot00000000000000workflow-0.11.8/.github/workflows/000077500000000000000000000000001476003635400170605ustar00rootroot00000000000000workflow-0.11.8/.github/workflows/ci.yml000066400000000000000000000017341476003635400202030ustar00rootroot00000000000000name: ci build on: push: branches: [ master ] pull_request: branches: [ master ] jobs: ubuntu-cmake: name: ubuntu runs-on: ubuntu-latest steps: - name: setup run: | sudo apt-get update sudo apt-get install cmake g++ libgtest-dev make libssl-dev sudo apt-get install redis valgrind - uses: actions/checkout@v2 - name: make run: make - name: make check run: make check - name: make tutorial run: make tutorial fedora-cmake: name: fedora runs-on: ubuntu-latest container: image: fedora:latest steps: - uses: actions/checkout@v3 - run: cat /etc/os-release - name: install dependencies run: | dnf -y update dnf -y install cmake gcc-c++ gtest-devel make dnf -y install openssl-devel redis valgrind - name: make run: make - name: make check run: make check - name: make tutorial run: make tutorial workflow-0.11.8/.github/workflows/xmake.yml000066400000000000000000000013051476003635400207070ustar00rootroot00000000000000name: xmake build on: workflow_dispatch: jobs: build: name: build-linux runs-on: ubuntu-latest steps: - name: install dependencies run: | sudo apt-get update sudo apt-get install -y g++ 
libssl-dev libgtest-dev - name: setup xmake uses: xmake-io/github-action-setup-xmake@v1 with: xmake-version: latest - name: pull code uses: actions/checkout@v2 - name: xmake run: | xmake -r xmake -g test xmake -g tutorial xmake -g benchmark - name : run shared run: | xmake f -k shared xmake -r xmake -g test xmake -g tutorial xmake -g benchmark workflow-0.11.8/.gitignore000066400000000000000000000006601476003635400154550ustar00rootroot00000000000000# Prerequisites *.d # Compiled Object files *.slo *.lo *.o *.obj # Precompiled Headers *.gch *.pch # Compiled Dynamic libraries *.so *.so.* *.dylib *.dll # Fortran module files *.mod *.smod # Compiled Static libraries *.lai *.la *.a *.lib # Executables *.exe *.out *.app # bazel env bazel-* # vscode configs .vscode # idea configs .idea cmake-build-debug/ workflow-config.cmake # xmake configs .xmake build.xmake _include workflow-0.11.8/BUILD000066400000000000000000000202531476003635400142470ustar00rootroot00000000000000config_setting( name = 'linux', constraint_values = [ "@platforms//os:linux", ], visibility = ['//visibility:public'], ) cc_library( name = 'workflow_hdrs', hdrs = glob(['src/include/workflow/*']), includes = ['src/include'], visibility = ["//visibility:public"], linkopts = [ '-lpthread', '-lssl', '-lcrypto', ], ) cc_library( name = 'common_c', srcs = [ 'src/kernel/mpoller.c', 'src/kernel/msgqueue.c', 'src/kernel/poller.c', 'src/kernel/rbtree.c', 'src/kernel/thrdpool.c', 'src/util/crc32c.c', 'src/util/json_parser.c', ], hdrs = glob(['src/*/*.h']) + glob(['src/*/*.inl']), includes = [ 'src/kernel', 'src/util', ], copts = ['-std=gnu90'], visibility = ["//visibility:public"], ) cc_library( name = 'common', srcs = [ 'src/client/WFDnsClient.cc', 'src/factory/DnsTaskImpl.cc', 'src/factory/FileTaskImpl.cc', 'src/factory/WFGraphTask.cc', 'src/factory/WFResourcePool.cc', 'src/factory/WFMessageQueue.cc', 'src/factory/WFTaskFactory.cc', 'src/factory/Workflow.cc', 'src/manager/DnsCache.cc', 'src/manager/RouteManager.cc', 
'src/manager/WFGlobal.cc', 'src/nameservice/WFDnsResolver.cc', 'src/nameservice/WFNameService.cc', 'src/protocol/TLVMessage.cc', 'src/protocol/DnsMessage.cc', 'src/protocol/DnsUtil.cc', 'src/protocol/SSLWrapper.cc', 'src/protocol/PackageWrapper.cc', 'src/protocol/dns_parser.c', 'src/server/WFServer.cc', 'src/kernel/CommRequest.cc', 'src/kernel/CommScheduler.cc', 'src/kernel/Communicator.cc', 'src/kernel/Executor.cc', 'src/kernel/SubTask.cc', ] + select({ ':linux': [ 'src/kernel/IOService_linux.cc', ], '//conditions:default': [ 'src/kernel/IOService_thread.cc', ], }) + glob(['src/util/*.cc']), hdrs = glob(['src/*/*.h']) + glob(['src/*/*.inl']), includes = [ 'src/algorithm', 'src/client', 'src/factory', 'src/kernel', 'src/manager', 'src/nameservice', 'src/protocol', 'src/server', 'src/util', ], deps = ['workflow_hdrs', 'common_c'], visibility = ["//visibility:public"], ) cc_library( name = 'http', hdrs = [ 'src/protocol/HttpMessage.h', 'src/protocol/HttpUtil.h', 'src/protocol/http_parser.h', 'src/server/WFHttpServer.h', ], includes = [ 'src/protocol', 'src/server', ], srcs = [ 'src/factory/HttpTaskImpl.cc', 'src/protocol/HttpMessage.cc', 'src/protocol/HttpUtil.cc', 'src/protocol/http_parser.c', ], deps = [ ':common', ], visibility = ["//visibility:public"], ) cc_library( name = 'redis', hdrs = [ 'src/factory/RedisTaskImpl.inl', 'src/protocol/RedisMessage.h', 'src/protocol/redis_parser.h', 'src/server/WFRedisServer.h', 'src/client/WFRedisSubscriber.h', ], includes = [ 'src/protocol', 'src/server', ], srcs = [ 'src/factory/RedisTaskImpl.cc', 'src/protocol/RedisMessage.cc', 'src/protocol/redis_parser.c', 'src/client/WFRedisSubscriber.cc', ], deps = [ ':common', ], visibility = ["//visibility:public"], ) cc_library( name = 'mysql', hdrs = [ 'src/protocol/MySQLMessage.h', 'src/protocol/MySQLMessage.inl', 'src/protocol/MySQLResult.h', 'src/protocol/MySQLResult.inl', 'src/protocol/MySQLUtil.h', 'src/protocol/mysql_byteorder.h', 'src/protocol/mysql_parser.h', 
'src/protocol/mysql_stream.h', 'src/protocol/mysql_types.h', 'src/server/WFMySQLServer.h', 'src/client/WFMySQLConnection.h', ], includes = [ 'src/protocol', 'src/client', 'src/server', ], srcs = [ 'src/factory/MySQLTaskImpl.cc', 'src/protocol/MySQLMessage.cc', 'src/protocol/MySQLResult.cc', 'src/protocol/MySQLUtil.cc', 'src/protocol/mysql_byteorder.c', 'src/protocol/mysql_parser.c', 'src/protocol/mysql_stream.c', 'src/client/WFMySQLConnection.cc', ], deps = [ ':common', ], visibility = ["//visibility:public"], ) cc_library( name = 'upstream', hdrs = [ 'src/manager/UpstreamManager.h', 'src/nameservice/UpstreamPolicies.h', 'src/nameservice/WFServiceGovernance.h', ], includes = [ 'src/manager', 'src/nameservice', ], srcs = [ 'src/manager/UpstreamManager.cc', 'src/nameservice/UpstreamPolicies.cc', 'src/nameservice/WFServiceGovernance.cc', ], deps = [ ':common', ], visibility = ["//visibility:public"], ) cc_library( name = 'kafka_message', hdrs = [ 'src/factory/KafkaTaskImpl.inl', 'src/protocol/KafkaDataTypes.h', 'src/protocol/KafkaMessage.h', 'src/protocol/KafkaResult.h', 'src/protocol/kafka_parser.h', ], includes = [ 'src/factory', 'src/protocol', ], srcs = [ 'src/factory/KafkaTaskImpl.cc', 'src/protocol/KafkaMessage.cc', ], copts = ['-fno-rtti'], deps = [ ':common', ], ) cc_library( name = 'kafka', hdrs = [ 'src/client/WFKafkaClient.h', 'src/factory/KafkaTaskImpl.inl', 'src/protocol/KafkaDataTypes.h', 'src/protocol/KafkaMessage.h', 'src/protocol/KafkaResult.h', 'src/protocol/kafka_parser.h', ], includes = [ 'src/client', 'src/factory', 'src/protocol', ], srcs = [ 'src/client/WFKafkaClient.cc', 'src/protocol/KafkaDataTypes.cc', 'src/protocol/KafkaResult.cc', 'src/protocol/kafka_parser.c', ], deps = [ ':common', ':kafka_message', ], visibility = ["//visibility:public"], linkopts = [ '-lsnappy', '-llz4', '-lz', '-lzstd', ], ) cc_library( name = 'consul', hdrs = [ 'src/client/WFConsulClient.h', 'src/protocol/ConsulDataTypes.h', ], includes = [ 'src/client', 
'src/factory', 'src/protocol', 'src/util', ], srcs = [ 'src/client/WFConsulClient.cc', ], deps = [ ':common', ':http', ], visibility = ["//visibility:public"], ) cc_binary( name = 'helloworld', srcs = ['tutorial/tutorial-00-helloworld.cc'], deps = [':http'], ) cc_binary( name = 'wget', srcs = ['tutorial/tutorial-01-wget.cc'], deps = [':http'], ) cc_binary( name = 'redis_cli', srcs = ['tutorial/tutorial-02-redis_cli.cc'], deps = [':redis'], ) cc_binary( name = 'wget_to_redis', srcs = ['tutorial/tutorial-03-wget_to_redis.cc'], deps = [':http', 'redis'], ) cc_binary( name = 'http_echo_server', srcs = ['tutorial/tutorial-04-http_echo_server.cc'], deps = [':http'], ) cc_binary( name = 'http_proxy', srcs = ['tutorial/tutorial-05-http_proxy.cc'], deps = [':http'], ) cc_binary( name = 'parallel_wget', srcs = ['tutorial/tutorial-06-parallel_wget.cc'], deps = [':http'], ) cc_binary( name = 'sort_task', srcs = ['tutorial/tutorial-07-sort_task.cc'], deps = [':common'], ) cc_binary( name = 'matrix_multiply', srcs = ['tutorial/tutorial-08-matrix_multiply.cc'], deps = [':common'], ) cc_binary( name = 'http_file_server', srcs = ['tutorial/tutorial-09-http_file_server.cc'], deps = [':http'], ) cc_library( name = 'user_hdrs', hdrs = ['tutorial/tutorial-10-user_defined_protocol/message.h'], includes = ['tutorial/tutorial-10-user_defined_protocol'], ) cc_binary( name = 'server', srcs = [ 'tutorial/tutorial-10-user_defined_protocol/server.cc', 'tutorial/tutorial-10-user_defined_protocol/message.cc', ], deps = [':common', ':user_hdrs'], ) cc_binary( name = 'client', srcs = [ 'tutorial/tutorial-10-user_defined_protocol/client.cc', 'tutorial/tutorial-10-user_defined_protocol/message.cc', ], deps = [':common', ':user_hdrs'], ) cc_binary( name = 'graph_task', srcs = ['tutorial/tutorial-11-graph_task.cc'], deps = [':http'], ) cc_binary( name = 'mysql_cli', srcs = ['tutorial/tutorial-12-mysql_cli.cc'], deps = [':mysql'], ) cc_binary( name = 'kafka_cli', srcs = 
['tutorial/tutorial-13-kafka_cli.cc'], deps = [':kafka', ':workflow_hdrs'], ) cc_binary( name = 'consul_cli', srcs = ['tutorial/tutorial-14-consul_cli.cc'], deps = [':consul'], ) cc_binary( name = 'name_service', srcs = ['tutorial/tutorial-15-name_service.cc'], deps = [':http'], ) cc_binary( name = 'graceful_restart_bootstrap', srcs = [ 'tutorial/tutorial-16-graceful_restart/bootstrap.c', ], ) cc_binary( name = 'graceful_restart_server', srcs = [ 'tutorial/tutorial-16-graceful_restart/server.cc', ], deps = [':http'], ) cc_binary( name = 'dns_cli', srcs = ['tutorial/tutorial-17-dns_cli.cc'], deps = [':common'], ) cc_binary( name = 'redis_subscriber', srcs = ['tutorial/tutorial-18-redis_subscriber.cc'], deps = [':redis'], ) cc_binary( name = 'dns_server', srcs = ['tutorial/tutorial-19-dns_server.cc'], deps = [':common'], ) workflow-0.11.8/CMakeLists.txt000066400000000000000000000071261476003635400162310ustar00rootroot00000000000000cmake_minimum_required(VERSION 3.6) set(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING "build type") set(CMAKE_SKIP_RPATH TRUE) project( workflow VERSION 0.11.8 LANGUAGES C CXX ) if (CYGWIN) message(FATAL_ERROR "Sorry, DO NOT support Cygwin") endif () if (MINGW) message(FATAL_ERROR "Sorry, DO NOT support MinGW") endif () include(GNUInstallDirs) set(CMAKE_CONFIG_INSTALL_FILE ${PROJECT_BINARY_DIR}/config.toinstall.cmake) set(CMAKE_CONFIG_INSTALL_DIR ${CMAKE_INSTALL_LIBDIR}/cmake/${PROJECT_NAME}) set(INC_DIR ${PROJECT_SOURCE_DIR}/_include CACHE PATH "workflow inc") set(LIB_DIR ${PROJECT_SOURCE_DIR}/_lib CACHE PATH "workflow lib") set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${LIB_DIR}) set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${LIB_DIR}) add_custom_target( LINK_HEADERS ALL COMMENT "link headers..." 
) INCLUDE(CMakeLists_Headers.txt) macro(makeLink src dest target) add_custom_command( TARGET ${target} PRE_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different ${src} ${dest} DEPENDS ${dest} ) endmacro() add_custom_command( TARGET LINK_HEADERS PRE_BUILD COMMAND ${CMAKE_COMMAND} -E make_directory ${INC_DIR}/${PROJECT_NAME} ) foreach(header_file ${INCLUDE_HEADERS} ${INCLUDE_KERNEL_HEADERS}) string(REPLACE "/" ";" arr ${header_file}) list(GET arr -1 file_name) makeLink(${PROJECT_SOURCE_DIR}/${header_file} ${INC_DIR}/${PROJECT_NAME}/${file_name} LINK_HEADERS) endforeach() message("CMAKE_C_FLAGS_DEBUG is ${CMAKE_C_FLAGS_DEBUG}") message("CMAKE_C_FLAGS_RELEASE is ${CMAKE_C_FLAGS_RELEASE}") message("CMAKE_C_FLAGS_RELWITHDEBINFO is ${CMAKE_C_FLAGS_RELWITHDEBINFO}") message("CMAKE_C_FLAGS_MINSIZEREL is ${CMAKE_C_FLAGS_MINSIZEREL}") message("CMAKE_CXX_FLAGS_DEBUG is ${CMAKE_CXX_FLAGS_DEBUG}") message("CMAKE_CXX_FLAGS_RELEASE is ${CMAKE_CXX_FLAGS_RELEASE}") message("CMAKE_CXX_FLAGS_RELWITHDEBINFO is ${CMAKE_CXX_FLAGS_RELWITHDEBINFO}") message("CMAKE_CXX_FLAGS_MINSIZEREL is ${CMAKE_CXX_FLAGS_MINSIZEREL}") if (WIN32) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /MP /wd4200") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP /wd4200 /std:c++14") else () set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wall -fPIC -pipe -std=gnu90") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -fPIC -pipe -std=c++11 -fno-exceptions -Wno-invalid-offsetof") if (APPLE) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-deprecated-declarations") endif() endif () add_subdirectory(src) ####CONFIG include(CMakePackageConfigHelpers) set(CONFIG_INC_DIR ${INC_DIR}) set(CONFIG_LIB_DIR ${LIB_DIR}) configure_package_config_file( ${PROJECT_NAME}-config.cmake.in ${PROJECT_SOURCE_DIR}/${PROJECT_NAME}-config.cmake INSTALL_DESTINATION ${CMAKE_CONFIG_INSTALL_DIR} PATH_VARS CONFIG_INC_DIR CONFIG_LIB_DIR ) set(CONFIG_INC_DIR ${CMAKE_INSTALL_INCLUDEDIR}) set(CONFIG_LIB_DIR ${CMAKE_INSTALL_LIBDIR}) configure_package_config_file( 
${PROJECT_NAME}-config.cmake.in ${CMAKE_CONFIG_INSTALL_FILE} INSTALL_DESTINATION ${CMAKE_CONFIG_INSTALL_DIR} PATH_VARS CONFIG_INC_DIR CONFIG_LIB_DIR ) write_basic_package_version_file( ${CMAKE_CURRENT_BINARY_DIR}/${PROJECT_NAME}-config-version.cmake VERSION ${WORKFLOW_VERSION} COMPATIBILITY AnyNewerVersion ) install( FILES ${CMAKE_CONFIG_INSTALL_FILE} DESTINATION ${CMAKE_CONFIG_INSTALL_DIR} COMPONENT devel RENAME ${PROJECT_NAME}-config.cmake ) install( FILES ${CMAKE_CURRENT_BINARY_DIR}/${PROJECT_NAME}-config-version.cmake DESTINATION ${CMAKE_CONFIG_INSTALL_DIR} COMPONENT devel ) install( FILES ${INCLUDE_HEADERS} ${INCLUDE_KERNEL_HEADERS} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/${PROJECT_NAME} COMPONENT devel ) install( FILES README.md DESTINATION "${CMAKE_INSTALL_DOCDIR}-${PROJECT_VERSION}" COMPONENT devel ) workflow-0.11.8/CMakeLists_Headers.txt000066400000000000000000000053451476003635400176650ustar00rootroot00000000000000cmake_minimum_required(VERSION 3.6) set(COMMON_KERNEL_HEADERS src/kernel/CommRequest.h src/kernel/CommScheduler.h src/kernel/Communicator.h src/kernel/SleepRequest.h src/kernel/ExecRequest.h src/kernel/IORequest.h src/kernel/Executor.h src/kernel/list.h src/kernel/mpoller.h src/kernel/poller.h src/kernel/msgqueue.h src/kernel/rbtree.h src/kernel/SubTask.h src/kernel/thrdpool.h ) if (CMAKE_SYSTEM_NAME STREQUAL "Linux" OR CMAKE_SYSTEM_NAME STREQUAL "Android") set(INCLUDE_KERNEL_HEADERS ${COMMON_KERNEL_HEADERS} src/kernel/IOService_linux.h ) elseif (UNIX) set(INCLUDE_KERNEL_HEADERS ${COMMON_KERNEL_HEADERS} src/kernel/IOService_thread.h ) else () message(FATAL_ERROR "IOService unsupported.") endif () set(INCLUDE_HEADERS src/protocol/ProtocolMessage.h src/protocol/http_parser.h src/protocol/HttpMessage.h src/protocol/HttpUtil.h src/protocol/redis_parser.h src/protocol/RedisMessage.h src/protocol/mysql_stream.h src/protocol/MySQLMessage.h src/protocol/MySQLMessage.inl src/protocol/MySQLResult.h src/protocol/MySQLResult.inl src/protocol/MySQLUtil.h 
src/protocol/mysql_parser.h src/protocol/mysql_types.h src/protocol/mysql_byteorder.h src/protocol/PackageWrapper.h src/protocol/SSLWrapper.h src/protocol/dns_parser.h src/protocol/DnsMessage.h src/protocol/DnsUtil.h src/protocol/TLVMessage.h src/protocol/ConsulDataTypes.h src/server/WFServer.h src/server/WFDnsServer.h src/server/WFHttpServer.h src/server/WFRedisServer.h src/server/WFMySQLServer.h src/client/WFMySQLConnection.h src/client/WFRedisSubscriber.h src/client/WFConsulClient.h src/client/WFDnsClient.h src/manager/DnsCache.h src/manager/WFGlobal.h src/manager/UpstreamManager.h src/manager/RouteManager.h src/manager/EndpointParams.h src/manager/WFFuture.h src/manager/WFFacilities.h src/manager/WFFacilities.inl src/util/json_parser.h src/util/EncodeStream.h src/util/LRUCache.h src/util/StringUtil.h src/util/URIParser.h src/factory/WFConnection.h src/factory/WFTask.h src/factory/WFTask.inl src/factory/WFGraphTask.h src/factory/WFTaskError.h src/factory/WFTaskFactory.h src/factory/WFTaskFactory.inl src/factory/WFAlgoTaskFactory.h src/factory/WFAlgoTaskFactory.inl src/factory/Workflow.h src/factory/WFOperator.h src/factory/WFResourcePool.h src/factory/WFMessageQueue.h src/factory/RedisTaskImpl.inl src/nameservice/WFNameService.h src/nameservice/WFDnsResolver.h src/nameservice/WFServiceGovernance.h src/nameservice/UpstreamPolicies.h ) if(KAFKA STREQUAL "y") set(INCLUDE_HEADERS ${INCLUDE_HEADERS} src/util/crc32c.h src/protocol/KafkaMessage.h src/protocol/KafkaDataTypes.h src/protocol/KafkaResult.h src/protocol/kafka_parser.h src/client/WFKafkaClient.h src/factory/KafkaTaskImpl.inl ) endif() workflow-0.11.8/CODE_OF_CONDUCT.md000066400000000000000000000062541476003635400162710ustar00rootroot00000000000000# Contributor Covenant Code of Conduct ## Our Pledge In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, 
regardless of age, body size, disability, ethnicity, gender identity and expression, level of experience, nationality, personal appearance, race, religion, or sexual identity and orientation. ## Our Standards Examples of behavior that contributes to creating a positive environment include: * Using welcoming and inclusive language * Being respectful of differing viewpoints and experiences * Gracefully accepting constructive criticism * Focusing on what is best for the community * Showing empathy towards other community members Examples of unacceptable behavior by participants include: * The use of sexualized language or imagery and unwelcome sexual attention or advances * Trolling, insulting/derogatory comments, and personal or political attacks * Public or private harassment * Publishing others’ private information, such as a physical or electronic address, without explicit permission * Other conduct which could reasonably be considered inappropriate in a professional setting ## Our Responsibilities Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. ## Scope This Code of Conduct applies within all project spaces, and it also applies when an individual is representing the project or its community in public spaces. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. 
Representation of a project may be further defined and clarified by project maintainers. ## Enforcement Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at xiehan@sogou-inc.com. All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project’s leadership. ## Attribution This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at [http://contributor-covenant.org/version/1/4][version] [homepage]: http://contributor-covenant.org [version]: http://contributor-covenant.org/version/1/4/ workflow-0.11.8/GNUmakefile000066400000000000000000000030041476003635400155320ustar00rootroot00000000000000ROOT_DIR := $(shell dirname $(realpath $(firstword $(MAKEFILE_LIST)))) ALL_TARGETS := all base check install preinstall clean tutorial MAKE_FILE := Makefile DEFAULT_BUILD_DIR := build.cmake BUILD_DIR := $(shell if [ -f $(MAKE_FILE) ]; then echo "."; else echo $(DEFAULT_BUILD_DIR); fi) CMAKE3 := $(shell if which cmake3>/dev/null ; then echo cmake3; else echo cmake; fi;) .PHONY: $(ALL_TARGETS) all: base make -C $(BUILD_DIR) -f Makefile base: mkdir -p $(BUILD_DIR) ifeq ($(DEBUG),y) cd $(BUILD_DIR) && $(CMAKE3) -D CMAKE_BUILD_TYPE=Debug -D CONSUL=$(CONSUL) -D KAFKA=$(KAFKA) -D MYSQL=$(MYSQL) -D REDIS=$(REDIS) -D UPSTREAM=$(UPSTREAM) $(ROOT_DIR) else ifneq ("${INSTALL_PREFIX}install_prefix", "install_prefix") cd $(BUILD_DIR) && $(CMAKE3) -DCMAKE_INSTALL_PREFIX:STRING=${INSTALL_PREFIX} -D CONSUL=$(CONSUL) -D KAFKA=$(KAFKA) -D MYSQL=$(MYSQL) -D REDIS=$(REDIS) -D 
UPSTREAM=$(UPSTREAM) $(ROOT_DIR) else cd $(BUILD_DIR) && $(CMAKE3) -D CONSUL=$(CONSUL) -D KAFKA=$(KAFKA) -D MYSQL=$(MYSQL) -D REDIS=$(REDIS) -D UPSTREAM=$(UPSTREAM) $(ROOT_DIR) endif tutorial: all make -C tutorial check: all make -C test check install preinstall: base mkdir -p $(BUILD_DIR) cd $(BUILD_DIR) && $(CMAKE3) $(ROOT_DIR) make -C $(BUILD_DIR) -f Makefile $@ clean: -make -C test clean -make -C tutorial clean rm -rf $(DEFAULT_BUILD_DIR) rm -rf _include rm -rf _lib find . -name CMakeCache.txt | xargs rm -f find . -name Makefile | xargs rm -f find . -name "*.cmake" | xargs rm -f find . -name CMakeFiles | xargs rm -rf workflow-0.11.8/LICENSE000066400000000000000000000261421476003635400144750ustar00rootroot00000000000000 Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. 
"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. 
Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. 
Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright 2020 Sogou Inc. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. workflow-0.11.8/LICENSE_GPLV2000066400000000000000000000432541476003635400154120ustar00rootroot00000000000000 GNU GENERAL PUBLIC LICENSE Version 2, June 1991 Copyright (C) 1989, 1991 Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed. Preamble The licenses for most software are designed to take away your freedom to share and change it. By contrast, the GNU General Public License is intended to guarantee your freedom to share and change free software--to make sure the software is free for all its users. This General Public License applies to most of the Free Software Foundation's software and to any other program whose authors commit to using it. (Some other Free Software Foundation software is covered by the GNU Lesser General Public License instead.) You can apply it to your programs, too. When we speak of free software, we are referring to freedom, not price. Our General Public Licenses are designed to make sure that you have the freedom to distribute copies of free software (and charge for this service if you wish), that you receive source code or can get it if you want it, that you can change the software or use pieces of it in new free programs; and that you know you can do these things. 
To protect your rights, we need to make restrictions that forbid anyone to deny you these rights or to ask you to surrender the rights. These restrictions translate to certain responsibilities for you if you distribute copies of the software, or if you modify it. For example, if you distribute copies of such a program, whether gratis or for a fee, you must give the recipients all the rights that you have. You must make sure that they, too, receive or can get the source code. And you must show them these terms so they know their rights. We protect your rights with two steps: (1) copyright the software, and (2) offer you this license which gives you legal permission to copy, distribute and/or modify the software. Also, for each author's protection and ours, we want to make certain that everyone understands that there is no warranty for this free software. If the software is modified by someone else and passed on, we want its recipients to know that what they have is not the original, so that any problems introduced by others will not reflect on the original authors' reputations. Finally, any free program is threatened constantly by software patents. We wish to avoid the danger that redistributors of a free program will individually obtain patent licenses, in effect making the program proprietary. To prevent this, we have made it clear that any patent must be licensed for everyone's free use or not licensed at all. The precise terms and conditions for copying, distribution and modification follow. GNU GENERAL PUBLIC LICENSE TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION 0. This License applies to any program or other work which contains a notice placed by the copyright holder saying it may be distributed under the terms of this General Public License. 
The "Program", below, refers to any such program or work, and a "work based on the Program" means either the Program or any derivative work under copyright law: that is to say, a work containing the Program or a portion of it, either verbatim or with modifications and/or translated into another language. (Hereinafter, translation is included without limitation in the term "modification".) Each licensee is addressed as "you". Activities other than copying, distribution and modification are not covered by this License; they are outside its scope. The act of running the Program is not restricted, and the output from the Program is covered only if its contents constitute a work based on the Program (independent of having been made by running the Program). Whether that is true depends on what the Program does. 1. You may copy and distribute verbatim copies of the Program's source code as you receive it, in any medium, provided that you conspicuously and appropriately publish on each copy an appropriate copyright notice and disclaimer of warranty; keep intact all the notices that refer to this License and to the absence of any warranty; and give any other recipients of the Program a copy of this License along with the Program. You may charge a fee for the physical act of transferring a copy, and you may at your option offer warranty protection in exchange for a fee. 2. You may modify your copy or copies of the Program or any portion of it, thus forming a work based on the Program, and copy and distribute such modifications or work under the terms of Section 1 above, provided that you also meet all of these conditions: a) You must cause the modified files to carry prominent notices stating that you changed the files and the date of any change. 
b) You must cause any work that you distribute or publish, that in whole or in part contains or is derived from the Program or any part thereof, to be licensed as a whole at no charge to all third parties under the terms of this License. c) If the modified program normally reads commands interactively when run, you must cause it, when started running for such interactive use in the most ordinary way, to print or display an announcement including an appropriate copyright notice and a notice that there is no warranty (or else, saying that you provide a warranty) and that users may redistribute the program under these conditions, and telling the user how to view a copy of this License. (Exception: if the Program itself is interactive but does not normally print such an announcement, your work based on the Program is not required to print an announcement.) These requirements apply to the modified work as a whole. If identifiable sections of that work are not derived from the Program, and can be reasonably considered independent and separate works in themselves, then this License, and its terms, do not apply to those sections when you distribute them as separate works. But when you distribute the same sections as part of a whole which is a work based on the Program, the distribution of the whole must be on the terms of this License, whose permissions for other licensees extend to the entire whole, and thus to each and every part regardless of who wrote it. Thus, it is not the intent of this section to claim rights or contest your rights to work written entirely by you; rather, the intent is to exercise the right to control the distribution of derivative or collective works based on the Program. In addition, mere aggregation of another work not based on the Program with the Program (or with a work based on the Program) on a volume of a storage or distribution medium does not bring the other work under the scope of this License. 3. 
You may copy and distribute the Program (or a work based on it, under Section 2) in object code or executable form under the terms of Sections 1 and 2 above provided that you also do one of the following: a) Accompany it with the complete corresponding machine-readable source code, which must be distributed under the terms of Sections 1 and 2 above on a medium customarily used for software interchange; or, b) Accompany it with a written offer, valid for at least three years, to give any third party, for a charge no more than your cost of physically performing source distribution, a complete machine-readable copy of the corresponding source code, to be distributed under the terms of Sections 1 and 2 above on a medium customarily used for software interchange; or, c) Accompany it with the information you received as to the offer to distribute corresponding source code. (This alternative is allowed only for noncommercial distribution and only if you received the program in object code or executable form with such an offer, in accord with Subsection b above.) The source code for a work means the preferred form of the work for making modifications to it. For an executable work, complete source code means all the source code for all modules it contains, plus any associated interface definition files, plus the scripts used to control compilation and installation of the executable. However, as a special exception, the source code distributed need not include anything that is normally distributed (in either source or binary form) with the major components (compiler, kernel, and so on) of the operating system on which the executable runs, unless that component itself accompanies the executable. 
If distribution of executable or object code is made by offering access to copy from a designated place, then offering equivalent access to copy the source code from the same place counts as distribution of the source code, even though third parties are not compelled to copy the source along with the object code. 4. You may not copy, modify, sublicense, or distribute the Program except as expressly provided under this License. Any attempt otherwise to copy, modify, sublicense or distribute the Program is void, and will automatically terminate your rights under this License. However, parties who have received copies, or rights, from you under this License will not have their licenses terminated so long as such parties remain in full compliance. 5. You are not required to accept this License, since you have not signed it. However, nothing else grants you permission to modify or distribute the Program or its derivative works. These actions are prohibited by law if you do not accept this License. Therefore, by modifying or distributing the Program (or any work based on the Program), you indicate your acceptance of this License to do so, and all its terms and conditions for copying, distributing or modifying the Program or works based on it. 6. Each time you redistribute the Program (or any work based on the Program), the recipient automatically receives a license from the original licensor to copy, distribute or modify the Program subject to these terms and conditions. You may not impose any further restrictions on the recipients' exercise of the rights granted herein. You are not responsible for enforcing compliance by third parties to this License. 7. If, as a consequence of a court judgment or allegation of patent infringement or for any other reason (not limited to patent issues), conditions are imposed on you (whether by court order, agreement or otherwise) that contradict the conditions of this License, they do not excuse you from the conditions of this License. 
If you cannot distribute so as to satisfy simultaneously your obligations under this License and any other pertinent obligations, then as a consequence you may not distribute the Program at all. For example, if a patent license would not permit royalty-free redistribution of the Program by all those who receive copies directly or indirectly through you, then the only way you could satisfy both it and this License would be to refrain entirely from distribution of the Program. If any portion of this section is held invalid or unenforceable under any particular circumstance, the balance of the section is intended to apply and the section as a whole is intended to apply in other circumstances. It is not the purpose of this section to induce you to infringe any patents or other property right claims or to contest validity of any such claims; this section has the sole purpose of protecting the integrity of the free software distribution system, which is implemented by public license practices. Many people have made generous contributions to the wide range of software distributed through that system in reliance on consistent application of that system; it is up to the author/donor to decide if he or she is willing to distribute software through any other system and a licensee cannot impose that choice. This section is intended to make thoroughly clear what is believed to be a consequence of the rest of this License. 8. If the distribution and/or use of the Program is restricted in certain countries either by patents or by copyrighted interfaces, the original copyright holder who places the Program under this License may add an explicit geographical distribution limitation excluding those countries, so that distribution is permitted only in or among countries not thus excluded. In such case, this License incorporates the limitation as if written in the body of this License. 9. 
The Free Software Foundation may publish revised and/or new versions of the General Public License from time to time. Such new versions will be similar in spirit to the present version, but may differ in detail to address new problems or concerns. Each version is given a distinguishing version number. If the Program specifies a version number of this License which applies to it and "any later version", you have the option of following the terms and conditions either of that version or of any later version published by the Free Software Foundation. If the Program does not specify a version number of this License, you may choose any version ever published by the Free Software Foundation. 10. If you wish to incorporate parts of the Program into other free programs whose distribution conditions are different, write to the author to ask for permission. For software which is copyrighted by the Free Software Foundation, write to the Free Software Foundation; we sometimes make exceptions for this. Our decision will be guided by the two goals of preserving the free status of all derivatives of our free software and of promoting the sharing and reuse of software generally. NO WARRANTY 11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 12. 
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. END OF TERMS AND CONDITIONS How to Apply These Terms to Your New Programs If you develop a new program, and you want it to be of the greatest possible use to the public, the best way to achieve this is to make it free software which everyone can redistribute and change under these terms. To do so, attach the following notices to the program. It is safest to attach them to the start of each source file to most effectively convey the exclusion of warranty; and each file should have at least the "copyright" line and a pointer to where the full notice is found. Copyright (C) This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 2 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program; if not, write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. Also add information on how to contact you by electronic and paper mail. 
If the program is interactive, make it output a short notice like this when it starts in an interactive mode: Gnomovision version 69, Copyright (C) year name of author Gnomovision comes with ABSOLUTELY NO WARRANTY; for details type `show w'. This is free software, and you are welcome to redistribute it under certain conditions; type `show c' for details. The hypothetical commands `show w' and `show c' should show the appropriate parts of the General Public License. Of course, the commands you use may be called something other than `show w' and `show c'; they could even be mouse-clicks or menu items--whatever suits your program. You should also get your employer (if you work as a programmer) or your school, if any, to sign a "copyright disclaimer" for the program, if necessary. Here is a sample; alter the names: Yoyodyne, Inc., hereby disclaims all copyright interest in the program `Gnomovision' (which makes passes at compilers) written by James Hacker. , 1 April 1989 Ty Coon, President of Vice This General Public License does not permit incorporating your program into proprietary programs. If your program is a subroutine library, you may consider it more useful to permit linking proprietary applications with the library. If this is what you want to do, use the GNU Lesser General Public License instead of this License. 
workflow-0.11.8/README.md000066400000000000000000000252121476003635400147440ustar00rootroot00000000000000[简体中文版(推荐)](README_cn.md) ## Sogou C++ Workflow [![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)](https://github.com/sogou/workflow/blob/master/LICENSE) [![Language](https://img.shields.io/badge/language-c++-red.svg)](https://en.cppreference.com/) [![Platform](https://img.shields.io/badge/platform-linux%20%7C%20macos%20%7C%20windows-lightgrey.svg)](https://img.shields.io/badge/platform-linux%20%7C%20macos%20%7C%20windows-lightgrey.svg) [![Build Status](https://img.shields.io/github/actions/workflow/status/sogou/workflow/ci.yml?branch=master)](https://github.com/sogou/workflow/actions?query=workflow%3A%22ci+build%22++) As **Sogou's C++ server engine**, Sogou C++ Workflow supports almost all **back-end C++ online services** of Sogou, including all search services, cloud input method, online advertisements, etc., handling more than **10 billion** requests every day. This is an **enterprise-level programming engine** in light and elegant design which can satisfy most C++ back-end development requirements. #### You can use it: * To quickly build an **HTTP server**: ~~~cpp #include <stdio.h> #include "workflow/WFHttpServer.h" int main() { WFHttpServer server([](WFHttpTask *task) { task->get_resp()->append_output_body("Hello World!"); }); if (server.start(8888) == 0) { // start server on port 8888 getchar(); // press "Enter" to end. server.stop(); } return 0; } ~~~ * As a **multifunctional asynchronous client**, it currently supports `HTTP`, `Redis`, `MySQL` and `Kafka` protocols. * ``MySQL`` protocol supports ``MariaDB``, ``TiDB`` as well. * To implement **client/server on user-defined protocol** and build your own **RPC system**. * [srpc](https://github.com/sogou/srpc) is based on it and it is an independent open source project, which supports srpc, brpc, trpc and thrift protocols.
* To build **asynchronous workflow**; support common **series** and **parallel** structures, and also support any **DAG** structures. * As a **parallel computing tool**. In addition to **networking tasks**, Sogou C++ Workflow also includes **the scheduling of computing tasks**. All types of tasks can be put into **the same** flow. * As an **asynchronous file IO tool** in `Linux` system, with high performance exceeding any system call. Disk file IO is also a task. * To realize any **high-performance** and **high-concurrency** back-end service with a very complex relationship between computing and networking. * To build a **micro service** system. * This project has built-in **service governance** and **load balancing** features. * Wiki link : [PaaS Architecture](https://github.com/sogou/workflow/wiki) #### Compiling and running environment * This project supports `Linux`, `macOS`, `Windows`, `Android` and other operating systems. * `Windows` version is currently released as an independent [branch](https://github.com/sogou/workflow/tree/windows), using `iocp` to implement asynchronous networking. All user interfaces are consistent with the `Linux` version. * Supports all CPU platforms, including 32 or 64-bit `x86` processors, big-endian or little-endian `arm` processors, `loongson` processors. * Master branch requires SSL and `OpenSSL 1.1` or above is recommended. Fully compatible with BoringSSL. If you don't like SSL, you may checkout the [nossl](https://github.com/sogou/workflow/tree/nossl) branch. * Uses the `C++11` standard and therefore, it should be compiled with a compiler which supports `C++11`. Does not rely on `boost` or `asio`. * No other dependencies. However, if you need `Kafka` protocol, some compression libraries should be installed, including `lz4`, `zstd` and `snappy`. 
### Get started (Linux, macOS): ~~~sh git clone https://github.com/sogou/workflow cd workflow make cd tutorial make ~~~~ #### With SRPC Tool (NEW!): https://github.com/sogou/srpc/blob/master/tools/README.md #### With [apt-get](https://launchpad.net/ubuntu/+source/workflow) on Debian Linux and Ubuntu: Sogou C++ Workflow has been packaged for Debian Linux and Ubuntu 22.04. To install the Workflow library for development purposes: ~~~~sh sudo apt-get install libworkflow-dev ~~~~ To install the Workflow library for deployment: ~~~~sh sudo apt-get install libworkflow1 ~~~~ #### With [dnf](https://packages.fedoraproject.org/pkgs/workflow) on Fedora Linux: Sogou C++ Workflow has been packaged for Fedora Linux. To install the Workflow library for development purposes: ~~~~sh sudo dnf install workflow-devel ~~~~ To install the Workflow library for deployment: ~~~~sh sudo dnf install workflow ~~~~ #### With xmake If you want to use xmake to build workflow, see the [xmake build document](docs/en/xmake.md) # Tutorials * Client * [Creating your first task:wget](docs/en/tutorial-01-wget.md) * [Implementing Redis set and get:redis\_cli](docs/en/tutorial-02-redis_cli.md) * [More features about series:wget\_to\_redis](docs/en/tutorial-03-wget_to_redis.md) * Server * [First server:http\_echo\_server](docs/en/tutorial-04-http_echo_server.md) * [Asynchronous server:http\_proxy](docs/en/tutorial-05-http_proxy.md) * Parallel tasks and Series  * [A simple parallel wget:parallel\_wget](docs/en/tutorial-06-parallel_wget.md) * Important topics * [About error](docs/en/about-error.md) * [About timeout](docs/en/about-timeout.md) * [About global configuration](docs/en/about-config.md) * [About DNS](docs/en/about-dns.md) * [About exit](docs/en/about-exit.md) * Computing tasks * [Using the built-in algorithm factory:sort\_task](docs/en/tutorial-07-sort_task.md) * [User-defined computing task:matrix\_multiply](docs/en/tutorial-08-matrix_multiply.md) * [Use computing task in a simple way: go
task](docs/en/about-go-task.md) * Asynchronous File IO tasks * [Http server with file IO:http\_file\_server](docs/en/tutorial-09-http_file_server.md) * User-defined protocol * [A simple user-defined protocol: client/server](docs/en/tutorial-10-user_defined_protocol.md) * [Use TLV message](docs/en/about-tlv-message.md) * Other important tasks/components * [About timer](docs/en/about-timer.md) * [About counter](docs/en/about-counter.md) * [About resource pool](docs/en/about-resource-pool.md) * [About module](docs/en/about-module.md) * [About DAG](docs/en/tutorial-11-graph_task.md) * Service governance * [About service governance](docs/en/about-service-governance.md) * [More documents about upstream](docs/en/about-upstream.md) * Connection context * [About connection context](docs/en/about-connection-context.md) * Built-in clients * [Asynchronous MySQL client:mysql\_cli](docs/en/tutorial-12-mysql_cli.md) * [Asynchronous Kafka client: kafka\_cli](docs/en/tutorial-13-kafka_cli.md) #### Programming paradigm We believe that a typical back-end program = protocol + algorithm + workflow, and that these three parts should be developed completely independently. * Protocol * In most cases, users use built-in common network protocols, such as HTTP, Redis or various rpc. * Users can also easily define a custom network protocol. To do so, they only need to provide serialization and deserialization functions to define their own client/server. * Algorithm * In our design, the algorithm is a concept symmetrical to the protocol. * If a protocol call is an rpc, then an algorithm call is an apc (Async Procedure Call). * We have provided some general algorithms, such as sort, merge, psort, reduce, which can be used directly. * Compared with a user-defined protocol, a user-defined algorithm is much more common. Any complicated computation with clear boundaries should be packaged into an algorithm.
* Workflow * Workflow is the actual business logic, which is to put the protocols and algorithms into the flow graph for use. * The typical workflow is a closed series-parallel graph. Complex business logic may be a non-closed DAG. * The workflow graph can be constructed directly or dynamically generated based on the results of each step. All tasks are executed asynchronously. Basic task, task factory and complex task * Our system contains six basic tasks: networking, file IO, CPU, GPU, timer, and counter. * All tasks are generated by the task factory and automatically recycled after callback. * A server task is a special kind of networking task, generated by the framework which calls the task factory, and handed over to the user through the process function. * In most cases, the task generated by the user through the task factory is a complex task, which is transparent to the user. * For example, an HTTP request may include many asynchronous processes (DNS, redirection), but for the user, it is just a networking task. * File sorting seems to be an algorithm, but it actually includes many complex interaction processes between file IO and CPU computation. * If you think of business logic as building circuits with well-designed electronic components, then each electronic component may be a complex circuit. Asynchrony and encapsulation based on `C++11 std::function` * Not based on user mode coroutines. Users need to know that they are writing asynchronous programs. * All calls are executed asynchronously, and there is almost no operation that occupies a thread. * Although we also provide some facilities with semi-synchronous interfaces, they are not core features. * We try to avoid user-level derivation, and encapsulate user behavior with `std::function` instead, including: * The callback of any task. * Any server's process. This conforms to the `FaaS` (Function as a Service) idea. * The realization of an algorithm is simply a `std::function`.
But the algorithm can also be implemented by derivation. Memory reclamation mechanism * Every task will be automatically reclaimed after the callback. If a task is created but a user does not want to run it, the user needs to release it through the dismiss method. * Any data in the task, such as the response of the network request, will also be recycled with the task. At this time, the user can use `std::move()` to move the required data. * SeriesWork and ParallelWork are two kinds of framework objects, which are also recycled after their callback. * When a series is a branch of a parallel, it will be recycled after the callback of the parallel that it belongs to. * This project doesn’t use `std::shared_ptr` to manage memory. #### Any other questions? You may check the [FAQ](https://github.com/sogou/workflow/issues/406) and [issues](https://github.com/sogou/workflow/issues) list first to see if you can find the answer. You are very welcome to send the problems you encounter in use to [issues](https://github.com/sogou/workflow/issues), and we will answer them as soon as possible. At the same time, more issues will also help new users. 
workflow-0.11.8/README_cn.md000066400000000000000000000255341476003635400154330ustar00rootroot00000000000000[English version](README.md) ## Sogou C++ Workflow [![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)](https://github.com/sogou/workflow/blob/master/LICENSE) [![Language](https://img.shields.io/badge/language-c++-red.svg)](https://en.cppreference.com/) [![Platform](https://img.shields.io/badge/platform-linux%20%7C%20macos%20%7C%20windows-lightgrey.svg)](https://img.shields.io/badge/platform-linux%20%7C%20macos%20%7C%20windows-lightgrey.svg) [![Build Status](https://img.shields.io/github/actions/workflow/status/sogou/workflow/ci.yml?branch=master)](https://github.com/sogou/workflow/actions?query=workflow%3A%22ci+build%22++) 搜狗公司C++服务器引擎,编程范式。支撑搜狗几乎所有后端C++在线服务,包括所有搜索服务,云输入法,在线广告等,每日处理数百亿请求。这是一个设计轻盈优雅的企业级程序引擎,可以满足大多数后端与嵌入式开发需求。 #### 你可以用来: * 快速搭建http服务器: ~~~cpp #include <stdio.h> #include "workflow/WFHttpServer.h" int main() { WFHttpServer server([](WFHttpTask *task) { task->get_resp()->append_output_body("Hello World!"); }); if (server.start(8888) == 0) { // start server on port 8888 getchar(); // press "Enter" to end.
server.stop(); } return 0; } ~~~ * 作为万能异步客户端。目前支持``http``,``redis``,``mysql``和``kafka``协议。 * 轻松构建效率极高的spider。 * ``mysql``协议同时也支持``MariaDB``和``TiDB``等数据库。 * 实现自定义协议client/server,构建自己的RPC系统。 * [srpc](https://github.com/sogou/srpc)就是以它为基础,作为独立项目开源。支持``srpc``,``brpc``,``trpc``和``thrift``等协议。 * 构建异步任务流,支持常用的串并联,也支持更加复杂的DAG结构。 * 作为并行计算工具使用。除了网络任务,我们也包含计算任务的调度。所有类型的任务都可以放入同一个流中。 * 在``Linux``系统下作为文件异步IO工具使用,性能超过任何标准调用。磁盘IO也是一种任务。 * 实现任何计算与通讯关系非常复杂的高性能高并发的后端服务。 * 构建微服务系统。 * 项目内置服务治理与负载均衡等功能。 * Wiki链接 : [PaaS 架构图](https://github.com/sogou/workflow/wiki) #### 编译和运行环境 * 项目支持``Linux``,``macOS``,``Windows``,``Android``等操作系统。 * ``Windows``版以[windows](https://github.com/sogou/workflow/tree/windows)分支发布,使用``iocp``实现异步网络。用户接口与``Linux``版一致。 * 支持所有CPU平台,包括32或64位``x86``处理器,大端或小端``arm``处理器,国产``loongson``龙芯处理器实测支持。 * 需要依赖于``OpenSSL``,推荐``OpenSSL 1.1``及以上版本。 * 不喜欢SSL的用户可以使用[nossl](https://github.com/sogou/workflow/tree/nossl)分支,代码更简洁。 * 项目使用了``C++11``标准,需要用支持``C++11``的编译器编译。但不依赖``boost``或``asio``。 * 项目无其它依赖。如需使用``kafka``协议,需自行安装``lz4``,``zstd``和``snappy``几个压缩库。 #### 快速开始(Linux, macOS): ~~~sh git clone https://github.com/sogou/workflow # From gitee: git clone https://gitee.com/sogou/workflow cd workflow make cd tutorial make ~~~ #### 使用SRPC工具(NEW!) 
SRPC工具可以生成完整的workflow工程,根据用户命令生成对应的server,client或proxy框架,以及CMake工程文件和JSON格式的配置文件。 并且,工具会下载最小的必要的依赖。例如在用户指定产生RPC项目时,自动下载并配置好protobuf等依赖。 SRPC工具的使用方法可以参考:https://github.com/sogou/srpc/blob/master/tools/README_cn.md #### Debian Linux或ubuntu上使用[apt-get](https://launchpad.net/ubuntu/+source/workflow)安装: 作为Debian Linux与Ubuntu Linux 22.04版自带软件,可以通过``apt-get``命令直接安装开发包: ~~~sh sudo apt-get install libworkflow-dev ~~~ 或部署运行环境: ~~~sh sudo apt-get install libworkflow1 ~~~ 注意ubuntu只有最新22.04版或以上自带workflow。更推荐用git直接下载最新源代码编译。 #### Fedora Linux上使用[dnf](https://packages.fedoraproject.org/pkgs/workflow)安装: Workflow也是Fedora Linux的自带软件,可以使用最新的rpm包管理工具``dnf``直接安装开发包: ~~~~sh sudo dnf install workflow-devel ~~~~ 或部署运行环境: ~~~~sh sudo dnf install workflow ~~~~ #### 使用xmake 如果你想用xmake去构建 workflow, 你可以看 [xmake build document](docs/xmake.md) # 示例教程 * Client基础 * [创建第一个任务:wget](docs/tutorial-01-wget.md) * [实现一次redis写入与读出:redis_cli](docs/tutorial-02-redis_cli.md) * [任务序列的更多功能:wget_to_redis](docs/tutorial-03-wget_to_redis.md) * Server基础 * [第一个server:http_echo_server](docs/tutorial-04-http_echo_server.md) * [异步server的示例:http_proxy](docs/tutorial-05-http_proxy.md) * 并行任务与工作流  * [一个简单的并行抓取:parallel_wget](docs/tutorial-06-parallel_wget.md) * 几个重要的话题 * [关于错误处理](docs/about-error.md) * [关于超时](docs/about-timeout.md) * [关于全局配置](docs/about-config.md) * [关于DNS](docs/about-dns.md) * [关于程序退出](docs/about-exit.md) * 计算任务 * [使用内置算法工厂:sort_task](docs/tutorial-07-sort_task.md) * [自定义计算任务:matrix_multiply](docs/tutorial-08-matrix_multiply.md) * [更加简单的使用计算任务:go_task](docs/about-go-task.md)【推荐】 * 文件异步IO任务 * [异步IO的http server:http_file_server](docs/tutorial-09-http_file_server.md) * 用户定义协议基础 * [简单的用户自定义协议client/server](docs/tutorial-10-user_defined_protocol.md) * [使用TLV格式消息](docs/about-tlv-message.md) * 其它一些重要任务与组件 * [关于定时器](docs/about-timer.md) * [关于计数器](docs/about-counter.md) * [模块任务](docs/about-module.md) * [DAG图任务](docs/tutorial-11-graph_task.md) * [Selector任务](docs/about-selector.md) * 任务间通信 *
[条件任务与观察者模式](docs/about-conditional.md) * [资源池与消息队列](docs/about-resource-pool.md) * 服务治理 * [关于服务治理](docs/about-service-governance.md) * [Upstream更多文档](docs/about-upstream.md) * [自定义名称服务策略](docs/tutorial-15-name_service.md) * 连接上下文的使用 * [关于连接上下文](docs/about-connection-context.md) * 内置客户端 * [异步MySQL客户端:mysql_cli](docs/tutorial-12-mysql_cli.md) * [异步kafka客户端:kafka_cli](docs/tutorial-13-kafka_cli.md) * [异步DNS客户端:dns_cli](docs/tutorial-17-dns_cli.md) * [Redis订阅客户端:redis_subscriber](docs/tutorial-18-redis_subscriber.md) #### 编程范式 程序 = 协议 + 算法 + 任务流 * 协议 * 大多数情况下,用户使用的是内置的通用网络协议,例如http,redis或各种rpc。 * 用户可以方便的自定义网络协议,只需提供序列化和反序列化函数,就可以定义出自己的client/server。 * 算法 * 在我们的设计里,算法是与协议对称的概念。 * 如果说协议的调用是rpc,算法的调用就是一次apc(Async Procedure Call)。 * 我们提供了一些通用算法,例如sort,merge,psort,reduce,可以直接使用。 * 与自定义协议相比,自定义算法的使用要常见得多。任何一次边界清晰的复杂计算,都应该包装成算法。 * 任务流 * 任务流就是实际的业务逻辑,就是把开发好的协议与算法放在流程图里使用起来。 * 典型的任务流是一个闭合的串并联图。复杂的业务逻辑,可能是一个非闭合的DAG。 * 任务流图可以直接构建,也可以根据每一步的结果动态生成。所有任务都是异步执行的。 结构化并发与任务隐藏 * 我们系统中包含五种基础任务:通讯,计算,文件IO,定时器,计数器。 * 一切任务都由任务工厂产生,用户通过调用接口组织并发结构。例如串联并联,DAG等。 * 大多数情况下,用户通过任务工厂产生的任务,都隐藏了多个异步过程,但用户并不感知。 * 例如,一次http请求,可能包含许多次异步过程(DNS,重定向),但对用户来讲,就是一次通信任务。 * 文件排序,看起来就是一个算法,但其实包括复杂的文件IO与CPU计算的交互过程。 * 如果把业务逻辑想象成用设计好的电子元件搭建电路,那么每个电子元件内部可能又是一个复杂电路。 * 任务隐藏机制大幅减少了用户需要创建的任务数量和回调深度。 * 任何任务都运行在某个串行流(series)里,共享series上下文,让异步任务之间数据传递变得简单。 回调与内存回收机制 * 一切调用都是异步执行,几乎不存在占着线程等待的操作。 * 显式的回调机制。用户清楚自己在写异步程序。 * **通过一套对象生命周期机制,大幅简化异步程序的内存管理** * 任何框架创建的任务,生命周期都是从创建到callback函数运行结束为止。没有泄漏风险。 * 如果创建了任务之后不想运行,则需要通过dismiss()接口删除。 * 任务中的数据,例如网络请求的resp,也会随着任务被回收。此时用户可通过``std::move()``把需要的数据移走。 * 项目中不使用任何智能指针来管理内存。代码观感清新。 * 尽量避免用户级别派生,以``std::function``封装用户行为,包括: * 任何任务的callback。 * 任何server的process。符合``FaaS``(Function as a Service)思想。 * 一个算法的实现,简单来讲也是一个``std::function``。 * 如果深入使用,又会发现一切皆可派生。 # 使用中有疑问? 
可以先查看[FAQ](https://github.com/sogou/workflow/issues/170)和[issues](https://github.com/sogou/workflow/issues)列表,看看是否能找到答案。 非常欢迎将您使用中遇到的问题发送到[issues](https://github.com/sogou/workflow/issues),我们将第一时间进行解答。同时更多的issue对新用户也会带来帮助。 也可以通过QQ群:**618773193** 联系我们。 qq_qrcode #### Gitee仓库 用户可以在访问GitHub遇到困难时,使用我们的Gitee官方仓库:https://gitee.com/sogou/workflow **另外也麻烦在Gitee上star了项目的用户,尽量同步star一下[GitHub仓库](https://github.com/sogou/workflow)。谢谢!** workflow-0.11.8/WORKSPACE000066400000000000000000000000001476003635400147320ustar00rootroot00000000000000workflow-0.11.8/benchmark/000077500000000000000000000000001476003635400154155ustar00rootroot00000000000000workflow-0.11.8/benchmark/CMakeLists.txt000066400000000000000000000024121476003635400201540ustar00rootroot00000000000000cmake_minimum_required(VERSION 3.6) set(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING "build type") project(benchmark LANGUAGES C CXX ) set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${PROJECT_SOURCE_DIR}) find_library(LIBRT rt) find_package(OpenSSL REQUIRED) find_package(workflow REQUIRED CONFIG HINTS ..) 
include_directories(${OPENSSL_INCLUDE_DIR} ${WORKFLOW_INCLUDE_DIR}) link_directories(${WORKFLOW_LIB_DIR}) if (WIN32) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /MP /wd4200") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP /wd4200 /std:c++14") else () set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wall -fPIC -pipe -std=gnu90") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -fPIC -pipe -std=c++11 -fno-exceptions -Wno-invalid-offsetof") if (APPLE) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-deprecated-declarations") endif() endif () set(BENCHMARK_LIST benchmark-01-http_server benchmark-02-http_server_long_req ) if (APPLE) set(WORKFLOW_LIB workflow pthread OpenSSL::SSL OpenSSL::Crypto) else () set(WORKFLOW_LIB workflow pthread OpenSSL::SSL OpenSSL::Crypto ${LIBRT}) endif () foreach(src ${BENCHMARK_LIST}) string(REPLACE "-" ";" arr ${src}) list(GET arr -1 bin_name) add_executable(${bin_name} ${src}.cc) target_link_libraries(${bin_name} ${WORKFLOW_LIB}) endforeach() workflow-0.11.8/benchmark/GNUmakefile000066400000000000000000000013111476003635400174630ustar00rootroot00000000000000ROOT_DIR := $(shell dirname $(realpath $(firstword $(MAKEFILE_LIST)))) ALL_TARGETS := all clean MAKE_FILE := Makefile DEFAULT_BUILD_DIR := build BUILD_DIR := $(shell if [ -f $(MAKE_FILE) ]; then echo "."; else echo $(DEFAULT_BUILD_DIR); fi) CMAKE3 := $(shell if which cmake3>/dev/null ; then echo cmake3; else echo cmake; fi;) .PHONY: $(ALL_TARGETS) all: mkdir -p $(BUILD_DIR) ifeq ($(DEBUG),y) cd $(BUILD_DIR) && $(CMAKE3) -D CMAKE_BUILD_TYPE=Debug $(ROOT_DIR) else cd $(BUILD_DIR) && $(CMAKE3) $(ROOT_DIR) endif make -C $(BUILD_DIR) -f Makefile clean: ifeq ($(MAKE_FILE), $(wildcard $(MAKE_FILE))) -make -f Makefile clean else ifeq (build, $(wildcard build)) -make -C build clean endif rm -rf build workflow-0.11.8/benchmark/README.md000066400000000000000000000162641476003635400167050ustar00rootroot00000000000000# 性能测试 Sogou C++ Workflow是一款性能优异的网络框架,本文介绍我们进行的性能测试, 包括方案、代码、结果,以及与其他同类产品的对比。 更多场景下的实验正在进行中,本文将持续更新。 ## HTTP 
Server HTTP Client/Server是Sogou C++ Workflow常见的应用场景, 我们首先对Server端进行实验。 ### 环境 我们部署了两台相同机器作为Server和Client,软硬件配置如下: | 软硬件 | 配置 | |:---:|:---| | CPU | 40 Cores, x86_64, Intel(R) Xeon(R) CPU E5-2630 v4 @ 2.20GHz | | Memory | 192GB | | NIC | 25000Mbps | | OS | CentOS 7.8.2003 | | Kernel | Linux version 3.10.0-1127.el7.x86_64 | | GCC | 4.8.5 | 两者间`ping`测得的RTT为0.1ms左右。 ### 对照组 我们选择nginx和brpc作为对照组。 选择前者是因为它在生产中部署十分广泛,性能不俗; 对于后者,我们在本次实验中只关注HTTP Server方面的能力, 其他的特性已有[单独的实验][Sogou RPC Benchmark]进行更为详尽的测试。 事实上,我们也对此二者之外的其他某些框架同时进行了实验, 但结果其性能表现相差较远,因此未在本文中体现。 后续我们将选取更多合适的框架加入对比测试中。 ### Client工具 本次实验我们使用的压测工具为[wrk][wrk]和[wrk2][wrk2]。 前者适合测试特定并发下的QPS极限和延时, 后者适合在特定QPS下测试延时分布。 我们也尝试过使用其他测试工具,例如[ab][ab]等,但无法打出足够的压力。 有鉴于此,我们也在着手开发基于Sogou C++ Workflow的benchmark工具。 ### 变量和指标 一般而言,对网络框架的性能测试,切入的角度可谓纷繁多样。 通过控制不同的变量、观测不同的指标,可以探究程序在不同场景下的适应能力。 本次实验,我们选择其中最普遍常见的变量和指标: 通过控制Client并发度和承载数据的大小,来测试QPS和延时的变化情况。 另外,我们还测试了在掺杂慢请求的正常请求的延时分布。 下面依次介绍两个测试场景。 ### 测试方法 #### 启动http server 1. 编译benchmark 2. 进入到benchmark目录,执行 ``` ./http_server 12 9000 11 ``` 说明: 启动参数分别为线程数、端口和响应的随机字符串长度。 ### wrk测试 ``` wrk --latency -d10 -c200 --timeout 8 -t 6 http://127.0.0.1:9000 ``` **命令行解释** -c200: 启动200个连接 -t6: 开启6个线程做压力测试 -d10: 压测持续10s --timeout 8: 连接超时时间8s ### 不同并发度和数据长度下的QPS和延时 #### 代码和配置 我们搭建了一个极其简约的HTTP服务器, 忽略掉所有的业务逻辑, 将测试点聚焦在纯粹的网络框架性能上。 代码片段如下, 完整代码移步[这里][benchmark-01 Code]。 ```cpp // ... auto * resp = task->get_resp(); resp->add_header_pair("Date", timestamp); resp->add_header_pair("Content-Type", "text/plain; charset=UTF-8"); resp->append_output_body_nocopy(content.data(), content.size()); // ... 
``` 可以从上述代码中看到, 对于到来的任何HTTP请求, 我们都会返回一段固定的内容作为Body, 并设置必要的Header, 包括代码中指明的`content-type`、`date`, 以及自动填充的`connection`和`content-length`。 HTTP Body的固定内容是在Server启动时随机生成的ASCII字符串, 其长度可以通过启动参数配置。 同时可以配置的还有使用的poller线程数和监听的端口号。 前者我们在本次测试中固定为16, 因此Sogou C++ Workflow将使用16个poller线程和20个handler线程(默认配置)。 对于nginx和brpc, 我们也构建了相同的返回内容, 并为nginx配置了40个进程、 brpc配置了40个线程。 #### 变量 我们控制并发度在`[1, 2K]`之间翻倍增长, 数据长度在`[16B, 64KB]`之间翻倍增长, 两者正交。 #### 指标 鉴于并发度和数据长度组合之后数量较多, 我们选择其中部分数据绘制为曲线。 ##### 固定数据长度下QPS与并发度关系 ![Concurrency and QPS][Con-QPS] 上图可以看出,当数据长度保持不变, QPS随着并发度提高而增大,后趋于平稳。 此过程中Sogou C++ Workflow一直有明显优势, 高于brpc和nginx。 特别是数据长度为64和512的两条曲线, 并发度足够的时候,可以保持500K的QPS。 注意上图中nginx-64与nginx-512的曲线重叠度很高, 不易辨识。 ##### 固定并发度下QPS与数据长度关系 ![Body Length and QPS][Len-QPS] 上图可以看出,当并发度保持不变, 随着数据长度的增长, QPS保持平稳至4K时下降。 此过程中,Sogou C++ Workflow也一直保持优势。 ##### 固定数据长度下延时与并发度关系 ![Concurrency and Latency][Con-Lat] 上图可以看出,保持数据长度不变, 延时随并发度提高而有所上升。 此过程中,Sogou C++ Workflow略好于brpc, 大好于nginx。 ##### 固定并发度下延时与数据长度关系 ![Body Length and Latency][Len-Lat] 上图可以看出,并发度保持不变时, 增大数据长度,造成延时上升。 此过程中,Sogou C++ Workflow好于nginx, 好于brpc。 ### 掺杂慢请求的延时分布 #### 代码 我们在上一个测试的基础上,简单添加了一个慢请求的逻辑, 模拟业务场景中可能出现的特殊情况。 代码片段如下, 完整代码请移步[这里][benchmark-02 Code]。 ```cpp // ... if (std::strcmp(uri, "/long_req/") == 0) { auto timer_task = WFTaskFactory::create_timer_task(microseconds, nullptr); series_of(task)->push_back(timer_task); } // ... 
``` 我们在Server的process里进行判断, 如果访问的是特定的路径, 则添加一个`WFTimerTask`到Series的末尾, 能够模拟一个异步耗时处理过程。 类似地,对brpc使用`bthread_usleep()`函数进行异步睡眠。 #### 配置 在本次实验中,我们固定并发度为1024,数据长度为1024字节, 分别以QPS为20K、100K和200K进行正常请求测试, 测绘延时; 与此同时,有另一路压力,进行慢请求, QPS是上述QPS的1%, 数据不计入统计。 慢请求的时长固定为5ms。 #### 延时CDF图 ![Latency CDF][Lat CDF] 从上图可以看出,当QPS为20K时, Sogou C++ Workflow略次于brpc; 当QPS为100K时,两者几乎相当; 当QPS为200K时,Sogou C++ Workflow略好于brpc。 总之,可以认为两者在这方面旗鼓相当。 [Sogou RPC Benchmark]: https://github.com/holmes1412/sogou-rpc-benchmark [wrk]: https://github.com/wg/wrk [wrk2]: https://github.com/giltene/wrk2 [ab]: https://httpd.apache.org/docs/2.4/programs/ab.html [benchmark-01 Code]: benchmark-01-http_server.cc [benchmark-02 Code]: benchmark-02-http_server_long_req.cc [Con-QPS]: https://raw.githubusercontent.com/wiki/sogou/workflow/img/benchmark-01.png [Len-QPS]: https://raw.githubusercontent.com/wiki/sogou/workflow/img/benchmark-02.png [Con-Lat]: https://raw.githubusercontent.com/wiki/sogou/workflow/img/benchmark-03.png [Len-Lat]: https://raw.githubusercontent.com/wiki/sogou/workflow/img/benchmark-04.png [Lat CDF]: https://raw.githubusercontent.com/wiki/sogou/workflow/img/benchmark-05.png workflow-0.11.8/benchmark/benchmark-01-http_server.cc000066400000000000000000000021661476003635400224440ustar00rootroot00000000000000#include #include #include #include #include "util/args.h" #include "util/content.h" #include "util/date.h" static WFFacilities::WaitGroup wait_group{1}; void signal_handler(int) { wait_group.done(); } int main(int argc, char ** argv) { size_t pollers; unsigned short port; size_t length; if (parse_args(argc, argv, pollers, port, length) != 3) { return -1; } std::signal(SIGINT, signal_handler); std::signal(SIGTERM, signal_handler); WFGlobalSettings settings = GLOBAL_SETTINGS_DEFAULT; settings.poller_threads = pollers; WORKFLOW_library_init(&settings); const std::string content = make_content(length); WFHttpServer server([&content](WFHttpTask * task) { auto * resp = task->get_resp(); char 
timestamp[32]; date(timestamp, sizeof(timestamp)); resp->add_header_pair("Date", timestamp); resp->add_header_pair("Content-Type", "text/plain; charset=UTF-8"); resp->append_output_body_nocopy(content.data(), content.size()); }); if (server.start(port) == 0) { wait_group.wait(); server.stop(); } return 0; } workflow-0.11.8/benchmark/benchmark-02-http_server_long_req.cc000066400000000000000000000026521476003635400243330ustar00rootroot00000000000000#include #include #include #include #include #include "util/args.h" #include "util/content.h" #include "util/date.h" static WFFacilities::WaitGroup wait_group{1}; void signal_handler(int) { wait_group.done(); } int main(int argc, char ** argv) { size_t pollers; unsigned short port; size_t length; size_t microseconds; if (parse_args(argc, argv, pollers, port, length, microseconds) != 4) { return -1; } std::signal(SIGINT, signal_handler); std::signal(SIGTERM, signal_handler); WFGlobalSettings settings = GLOBAL_SETTINGS_DEFAULT; settings.poller_threads = pollers; WORKFLOW_library_init(&settings); const std::string content = make_content(length); WFHttpServer server([&content, µseconds](WFHttpTask * task) { auto resp = task->get_resp(); char timestamp[32]; date(timestamp, sizeof(timestamp)); resp->add_header_pair("Date", timestamp); resp->add_header_pair("Content-Type", "text/plain; charset=UTF-8"); resp->append_output_body_nocopy(content.data(), content.size()); auto req = task->get_req(); auto uri = req->get_request_uri(); if (std::strcmp(uri, "/long_req/") == 0) { auto timer_task = WFTaskFactory::create_timer_task(microseconds, nullptr); series_of(task)->push_back(timer_task); } }); if (server.start(port) == 0) { wait_group.wait(); server.stop(); } return 0; } workflow-0.11.8/benchmark/util/000077500000000000000000000000001476003635400163725ustar00rootroot00000000000000workflow-0.11.8/benchmark/util/args.h000066400000000000000000000035531476003635400175050ustar00rootroot00000000000000#ifndef _BENCHMARK_ARGS_H_ #define 
_BENCHMARK_ARGS_H_ #include #include #include namespace details { inline bool extract(const char * p, size_t & t) { char * e; long long ll = std::strtoll(p, &e, 0); if (*e || ll < 0) { return false; } t = static_cast(ll); return true; } inline bool extract(const char * p, unsigned short & t) { char * e; long long ll = std::strtoll(p, &e, 0); if (*e || ll < static_cast(std::numeric_limits::min()) || ll > static_cast(std::numeric_limits::max()) ) { return false; } t = static_cast(ll); return true; } inline bool extract(const char * p, std::string & t) { t = p; return true; } inline bool extract(const char * p, const char *& t) { t = p; return true; } template inline int parse_one(bool & flag, char **& p, char ** end, ARG & arg) { if (flag && (flag = p < end) && (flag = extract(*p, arg))) { p++; } return 0; } template inline size_t parse_all(char ** begin, char ** end, ARGS & ... args) { bool flag = true; char ** p = begin; static_cast(std::initializer_list{parse_one(flag, p, end, args) ...}); return p - begin; } template inline size_t parse_args(int & argc, char ** argv, ARGS & ... args) { if (argc <= 1) { return 0; } size_t length = argc - 1; char ** begin = argv + 1; char ** end = begin + length; size_t done = parse_all(begin, end, args ...); std::rotate(begin, begin + done, end); std::reverse(end - done, end); argc -= done; return done; } } template inline static size_t parse_args(int & argc, char ** argv, ARGS & ... 
args) { return details::parse_args(argc, argv, args ...); } #endif //_BENCHMARK_ARGS_H_ workflow-0.11.8/benchmark/util/content.h000066400000000000000000000006171476003635400202210ustar00rootroot00000000000000#ifndef _BENCHMARK_CONTENT_H_ #define _BENCHMARK_CONTENT_H_ #include #include static inline std::string make_content(size_t length) { std::mt19937_64 gen{42}; std::uniform_int_distribution dis{32, 126}; std::string s; s.reserve(length); for (size_t i = 0; i < length; i++) { s.push_back(static_cast(dis(gen))); } return s; } #endif //_BENCHMARK_CONTENT_H_ workflow-0.11.8/benchmark/util/date.h000066400000000000000000000004561476003635400174650ustar00rootroot00000000000000#ifndef _BENCHMARK_DATE_H_ #define _BENCHMARK_DATE_H_ #include static inline void date(char * buf, size_t n) { auto tt = std::time(nullptr); std::tm cur{}; // gmtime_r(&tt, &cur); localtime_r(&tt, &cur); strftime(buf, n, "%a, %d %b %Y %H:%M:%S %Z", &cur); } #endif //_BENCHMARK_DATE_H_ workflow-0.11.8/benchmark/xmake.lua000066400000000000000000000010571476003635400172300ustar00rootroot00000000000000set_group("benchmark") set_default(false) add_deps("workflow") if not is_plat("macosx") then add_ldflags("-lrt") end function all_benchs() local res = {} for _, x in ipairs(os.files("**.cc")) do local item = {} local s = path.filename(x) table.insert(item, s:sub(1, #s - 3)) -- target table.insert(item, path.relative(x, ".")) -- source table.insert(res, item) end return res end for _, bench in ipairs(all_benchs()) do target(bench[1]) set_kind("binary") add_files(bench[2]) end workflow-0.11.8/docs/000077500000000000000000000000001476003635400144135ustar00rootroot00000000000000workflow-0.11.8/docs/about-conditional.md000066400000000000000000000116511476003635400203540ustar00rootroot00000000000000# 条件任务与观察者模式 有的时候,我们需要让任务在某个条件下才被执行。条件任务(WFConditional)就是用于解决这种问题。 条件任务是一种任务包装器,可以包装任何的任务并取代原任务。通过对条件任务发送信号来触发被包装任务的执行。 # 条件任务的创建 在[WFTaskFactory.h](/src/factory/WFTaskFactory.h)里,可以看到条件任务的创建接口。 ~~~cpp class 
WFTaskFactory { public: static WFConditional *create_conditional(SubTask *task); static WFConditional *create_conditional(SubTask *task, void **msgbuf); }; ~~~ 可以看到,我们通过工厂的create_conditional接口创建条件任务。 其中,task为被包装的任务。msgbuf是用于接收消息的缓冲区,如果无需关注消息的具体内容,msgbuf可以缺省。 WFConditional的主要接口: ~~~cpp class WFConditional : public WFGenericTask { public: virtual void signal(void *msg); ... }; ~~~ WFConditional是一种任务,所以,它满足普通workflow任务的一切属性。特别的接口只有signal,用于发送信号。 # 示例 以下示例,通过timer和conditional,实现一个延迟1秒执行的计算任务。 ~~~cpp int main() { WFGoTask *task = WFTaskFactory::create_go_task("test", [](){ printf("Done\n"); }); WFConditional *cond = WFTaskFactory::create_conditional(task); WFTimerTask *timer = WFTaskFactory::create_timer_task(1, 0, [cond](void *){ cond->signal(NULL); }); timer->start(); cond->start(); getchar(); } ~~~ 这个示例里,在定时器的回调里向cond发送信号,让被包装的go task可以被执行。 注意,无论cond->signal()与cond->start()哪一个先被调用,程序都完全正确。 # 观察者模式 我们看到,如果直接对cond发送信息,需要发送者直接持有cond的指针,这在一些情况下并不是很方便。 于是,我们引入了观察者模式,也就是命名的条件任务。通过向某个名称发送信号,同时唤醒所有在这个名称下的条件任务。 命名条件任务的创建与唤醒: ~~~cpp class WFTaskFactory { public: static WFConditional *create_conditional(const std::string& cond_name, SubTask *task); static WFConditional *create_conditional(const std::string& cond_name, SubTask *task, void **msgbuf); static int signal_by_name(const std::string& cond_name, void *msg); static int signal_by_name(const std::string& cond_name, void *msg, size_t max); template static int signal_by_name(const std::string& cond_name, T *const msg[], size_t max); }; ~~~ 我们看到,与普通条件任务唯一区别是,命名条件任务创建时,需要传入一个cond_name。 而signal_by_name()接口,默认将msg发送到所有在这个名称上等待的条件任务,将它们全部唤醒。 也可以通过max参数指定唤醒的最大任务数。此时,msg还可以是一个指针数组,可给不同的条件任务发送不同的消息。 任何一个signal_by_name的重载函数,其返回值都是表示实际唤醒的条件任务个数。 这就相当于实现了观察者模式。 # 示例 还是上面的延迟计算示例,我们增加到两个计算任务并用观察者模式来实现。用"slot1"作为条件任务名。 ~~~cpp int main() { WFGoTask *task1 = WFTaskFactory::create_go_task("test", [](){ printf("test1 done\n"); }); WFGoTask *task2 = WFTaskFactory::create_go_task("test", [](){ printf("test2 done\n"); }); WFConditional *cond1 = 
WFTaskFactory::create_conditional("slot1", task1); WFConditional *cond2 = WFTaskFactory::create_conditional("slot1", task2); WFTimerTask *timer = WFTaskFactory::create_timer_task(1, 0, [](void *){ WFTaskFactory::signal_by_name("slot1", NULL); }); timer->start(); cond1->start(); cond2->start(); getchar(); } ~~~ 我们看到,在这个示例里,timer在回调中通过signal_by_name方法,同时唤醒了slot1下两个计算任务。 # 使用条件任务注意事项 Workflow里的任何任务,如果创建之后不想运行,都可以通过dismiss接口直接释放。 对于条件任务,如果要被dismiss(或者在某个被cancel的series里),必须保证这个条件任务没有被signal过。 以下代码的行为无定义: ~~~cpp int main() { WFEmptyTask *task = WFTaskFactory::create_empty_task(); WFConditional *cond = WFTaskFactory::create_conditional("slot1", task); WFTimerTask *timer = WFTaskFactory::create_timer_task(0, 0, [](void *) { WFTaskFactory::signal_by_name("slot1"); }); timer->start(); cond->dismiss(); // 取消任务 getchar(); } ~~~ 显然,如果timer的callback里已经执行或正在执行了signal_by_name,cond被signal,再dismiss()是一种错误行为。 这种情况一般也只会出现在命名条件任务里。所以,dismiss一个命名条件任务,需要特别的小心。 workflow-0.11.8/docs/about-config.md000066400000000000000000000072531476003635400173210ustar00rootroot00000000000000# 关于全局配置 全局配置用于配置全局默认参数,以适应的实际业务需求,提升程序性能。 全局配置的修改必须在使用框架任何调用之前,否则修改可能无法生效。 另外,一些全局配置选项,可以被upstream配置覆盖。这部分请参考upstream相关文档。 # 修改默认配置 在[WFGlobal.h](../src/manager/WFGlobal.h)里,包含了全局配置的结构体与默认值: ~~~cpp struct WFGlobalSettings { struct EndpointParams endpoint_params; struct EndpointParams dns_server_params; unsigned int dns_ttl_default; ///< in seconds, DNS TTL when network request success unsigned int dns_ttl_min; ///< in seconds, DNS TTL when network request fail int dns_threads; int poller_threads; int handler_threads; int compute_threads; ///< auto-set by system CPU number if value<=0 int fio_max_events; const char *resolv_conf_path; const char *hosts_path; }; static constexpr struct WFGlobalSettings GLOBAL_SETTINGS_DEFAULT = { .endpoint_params = ENDPOINT_PARAMS_DEFAULT, .dns_server_params = ENDPOINT_PARAMS_DEFAULT, .dns_ttl_default = 12 * 3600, .dns_ttl_min = 180, .dns_threads = 4, .poller_threads = 4, 
.handler_threads = 20, .compute_threads = -1, .fio_max_events = 4096, .resolv_conf_path = "/etc/resolv.conf", .hosts_path = "/etc/hosts", }; ~~~ 其中EndpointParams结构体和默认值在[EndpointParams.h](../src/manager/EndpointParams.h)文件里: ~~~cpp struct EndpointParams { size_t max_connections; int connect_timeout; int response_timeout; int ssl_connect_timeout; bool use_tls_sni; }; static constexpr struct EndpointParams ENDPOINT_PARAMS_DEFAULT = { .max_connections = 200, .connect_timeout = 10 * 1000, .response_timeout = 10 * 1000, .ssl_connect_timeout = 10 * 1000, .use_tls_sni = false, }; ~~~ 举个例子,把默认的连接超时改为5秒,dns默认ttl改为1小时,用于消息反序列化的poller线程增加到10个: ~~~cpp #include "workflow/WFGlobal.h" int main() { struct WFGlobalSettings settings = GLOBAL_SETTINGS_DEFAULT; settings.endpoint_params.connect_timeout = 5 * 1000; settings.dns_ttl_default = 3600; settings.poller_threads = 10; WORKFLOW_library_init(&settings); ... } ~~~ 大多数参数的意义都比较清晰。注意dns ttl相关参数,单位是**秒**。endpoint相关超时参数单位是**毫秒**,并且可以用-1表示无限。 dns_threads表示并行访问dns的线程数。但目前我们默认使用我们自己的异步DNS解析,所以并不会创建DNS线程(Windows平台除外)。 dns_server_params表示我们访问DNS server的参数,包括最大并发连接,以及连接与响应超时。 compute_threads表示用于计算的线程数,默认-1代表与当前节点CPU核数相同。 fio_max_events是异步文件IO的最大并发事件数。 resolv_conf_path是dns配置文件的路径,unix平台下默认为"/etc/resolv.conf"。Windows下默认为NULL,将使用多线程dns解析。 hosts_path是hosts文件路径。unix平台下默认为"/etc/hosts"。只有配置了resolv_conf_path,这个配置才起作用。 与网络性能相关的两个参数为poller_threads和handler_threads: * poller线程主要负责epoll(kqueue)和消息反序列化。 * handler线程是网络任务callback和process所在线程。 所有框架需要的资源,都是在第一次被使用时才申请的。例如用户没有用到dns解析,那么异步dns解析器或dns线程不会被启动。 workflow-0.11.8/docs/about-connection-context.md000066400000000000000000000137521476003635400216760ustar00rootroot00000000000000# 关于连接上下文 连接上下文是使用本框架编程的一个高级课题。使用上会有一些复杂性。 从之前的示例里可以看出,无论是client还是server任务,我们并没有手段指定使用的具体连接。 但是有一些业务场景,特别是server端,可能是需要维护连接状态的。也就是说我们需要把一段上下文和连接绑定。 我们的框架里,是提供了连接上下文机制给用户使用的。 # 连接上下文的应用场景 http协议可以说是一种完全无连接状态的协议,http会话,是通过cookie来实现的。这种协议对于我们的框架最友好。类似的还有kafka。 而redis和mysql的连接则是明显带状态,redis通过SELECT命令,指定当前连接上的数据库ID。mysql则是一个彻彻底底的有状态连接。 
使用框架的redis或非事务mysql client任务时,由于URL里已经包含了所有和连接选择有关的信息,包括: * 用户名密码 * 数据库名或数据库号 * mysql的字符集 框架会根据这些信息自动登录和选择可复用的连接,用户无需关心连接上下文的问题。 这也是为什么,框架里redis的SELECT命令和mysql的USE命令是禁止用户使用的,切换数据库需要用新的URL创建任务。 事务型mysql,可以固定连接,这部分内容请参考mysql相关文档。 但是,如果我们实现一个redis协议的server,那我们就需要知道当前连接上的状态了。 此外,我们还可以通过连接上下文被释放的事件来感知连接被远端关闭。 # 使用连接上下文的方法 我们需要强调的是,一般情况下只有server任务需要使用连接上下文,并且只需要在process函数内部使用,这也是最安全最简单的用法。 但是,任务在callback里也可以使用或修改连接上下文,只是使用的时候需要考虑并发的问题。我们会详细地讨论相关问题。 任何网络任务都可以调用接口获得连接对象,进而获得或修改连接上下文。在[WFTask.h](../src/factory/WFTask.h)里,调用如下: ~~~cpp template class WFNetworkTask : public CommRequest { public: virtual WFConnection *get_connection() const = 0; ... }; ~~~ 文件[WFConnection.h](../src/factory/WFConnection.h)里,包含了对连接对象的操作接口: ~~~cpp class WFConnection : public CommConnection { public: void *get_context() const; void set_context(void *context, std::function deleter); void set_context(void *context); void *test_set_context(void *test_context, void *new_context, std::function deleter); void *test_set_context(void *test_context, void *new_context); }; ~~~ get_connection()只可在process或callback里调用,而且如果callback里调用,需要检查返回值是否为NULL。 如果成功取得WFConnection对象,就可以操作连接上下文了。连接上下文是一个void *指针。 设置连接上下文可以同时传入deleter函数,在连接被关闭时,deleter被自动调用。 如果调用无deleter参数的接口,可以只设置新的上下文,保持原有的deleter不变。 # 访问连接上下文的时机和并发问题 client task被创建的时候,连接对象没有确定,因此所有client task对连接上下文的使用只有在callback里。 server task可能在两个地方使用连接上下文,process和callback。 在callback里使用连接上下文时,需要考虑并发问题,因为同一个连接,会被多个task复用,并且同时运行到callback。 所以,我们推荐只在process函数里访问或修改连接上下文,process过程中连接不会被复用或释放,是最简单安全的方法。 注意,我们指的process只包括process函数内部,在process函数结束后,callback之前,get_connection调用一律返回NULL。 WFConnection的test_set_context(),就是为了解决callback里使用连接上下文时的并发问题,但我们不推荐使用。 总之,如果你不是对系统实现非常了解,请只在server task的process函数里使用连接上下文。 # 示例:减少Http/1.1的请求header传输 http协议可以说是一个连接无状态的协议,同一个连接上,每一次请求都必须发送完整的header。 假设请求里的cookie非常大,那么这显然就增加了很大的数据传输量。我们可以通过server端连接上下文来解决这个问题。 我们约定http request里的cookie,对本连接上所有后续请求有效,后续请求header里可以不再发送cookie。 以下是server端代码: ~~~cpp void process(WFHttpTask *server_task) { 
protocol::HttpRequest *req = server_task->get_req(); protocol::HttpHeaderCursor cursor(req); WFConnection *conn = server_task->get_connection(); void *context = conn->get_context(); std::string cookie; if (cursor.find("Cookie", cookie)) { if (context) delete (std::string *)context; context = new std::string(cookie); conn->set_context(context, [](void *p) { delete (std::string *)p; }); } else if (context) cookie = *(std::string *)context; ... } ~~~ 通过这种方式,与client端约定好每次只在连接的第一个请求传输cookie,就可以实现流量的节省。 client端的实现需要用到一个新的回调函数,用法如下: ~~~cpp using namespace protocol; void prepare_func(WFHttpTask *task) { if (task->get_task_seq() == 0) task->get_req()->add_header_pair("Cookie", my_cookie); } int some_function() { WFHttpTask *task = WFTaskFactory::create_http_task(...); task->set_prepare(prepare_func); ... } ~~~ 在这个示例中,当http task是连接上的首个请求时,我们设置了cookie。如果不是首个请求,根据约定,不再设置cookie。 另外,prepare函数里,可以安全的使用连接上下文。同一个连接上,prepare不会并发。 workflow-0.11.8/docs/about-counter.md000066400000000000000000000177241476003635400175370ustar00rootroot00000000000000# 关于计数器 计数器是我们框架中一种非常重要的基础任务,计数器本质上是一个不占线程的信号量。 计数器主要用于工作流的控制,包括匿名计数器和命名计数器两种,可以实现非常复杂的业务逻辑。 # 计数器的创建 由于计数器也是一种任务,它的创建同样通过WFTaskFactory来完成,包括两种创建方法: ~~~cpp using counter_callback_t = std::function; class WFTaskFactory { ... static WFCounterTask *create_counter_task(unsigned int target_value, counter_callback_t callback); static WFCounterTask *create_counter_task(const std::string& counter_name, unsigned int target_value, counter_callback_t callback); ... }; ~~~ 每个计数器都包含一个target_value,当计数器的计数到达target_value,callback被调用。 以上两个接口分别产生匿名计数器和命名计数器,匿名计数器直接通过WFCounterTask的count方法来增加计数: ~~~cpp class WFCounterTask { public: virtual void count() { ... } ... 
} ~~~ 如果创建计数器时,传入一个counter_name,则产生一个命名计数器,可以通过count_by_name函数来增加计数。 # 用匿名计数器实现任务并行 在[并行抓取](./tutorial-06-parallel_wget.md)的示例中,我们通过创建一个ParallelWork来实现多个series并行。 通过ParallelWork和SeriesWork的组合,可以构建任意的串并连图,已经可以满足大多数应用场景需求。 而计数器的存在,可以让我们构建更复杂的任务依赖关系,比如实现一个全连接的神经网络。 以下简单的代码,可代替ParallelWork,实现一个并行的http抓取。 ~~~cpp void http_callback(WFHttpTask *task) { /* Save http page. */ ... WFCounterTask *counter = (WFCounterTask *)task->user_data; counter->count(); } std::mutex mutex; std::condition_variable cond; bool finished = false; void counter_callback(WFCounterTask *counter) { mutex.lock(); finished = true; cond.notify_one(); mutex.unlock(); } int main(int argc, char *argv[]) { WFCounterTask *counter = create_counter_task(url_count, counter_callback); WFHttpTask *task; std::string url[url_count]; /* init urls */ ... for (int i = 0; i < url_count; i++) { task = create_http_task(url[i], http_callback); task->user_data = counter; task->start(); } counter->start(); std::unique_lock lock(mutex); while (!finished) cond.wait(lock); lock.unlock(); return 0; } ~~~ 以上创建一个目标值为url_count的计数器,每个http任务完成之后,调用一次count。 注意,匿名计数器的count次数不可以超过目标值,否则counter可能已经callback销毁了,程序行为无定义。 counter->start()调用可以放在for循环之前。counter只要被创建,就可以调用其count接口,无论counter是否已经启动。 匿名计数器的count接口调用,也可以写成counter->WFCounterTask::count(); 在非常注重性能的应用下可以这么用。 # Server与其它异步引擎结合使用 某些情况下,我们的server可能需要调用非本框架的异步客户端等待结果。简单的方法我们可以在process里同步等待,通过条件变量来唤醒。 这么做的缺点是我们占用了一个处理线程,把其它框架的异步客户端变为同步客户端。但通过counter,我们可以不占线程地等待。 方法很简单: ~~~cpp void some_callback(void *context) { protocol::HttpResponse *resp = get_resp_from_context(context); WFCounterTask *counter = get_counter_from_context(context); /* write data to resp. */ ... 
counter->count(); } void process(WFHttpTask *task) { WFCounterTask *counter = WFTaskFactory::create_counter_task(1, nullptr); SomeOtherAsyncClient client(some_callback, context); *series_of(task) << counter; } ~~~ 在这里,我们可以把server任务所在的series理解为一个协程,而目标值为1的counter,可以理解为一个条件变量。 Counter的缺点是count操作不传递数据。如果业务有数据传达的需求,可以使用[Mailbox任务](https://github.com/sogou/workflow/blob/master/src/factory/WFTaskFactory.h#L268)。 # 命名计数器 对匿名计数器进行count操作时,直接访问了counter对象指针。这就必然要求在操作时,调用count的次数不超过目标值。 但想象这样一个应用场景,我们同时启动4个任务,只要其中有任意3个任务完成,工作流就可以继续进行。 我们可以用一个目标值为3的计数器,每个任务完成之后,count一次,这样只要任意3个任务完成,计数器就被callback。 但这样的问题是,当第4个任务完成,再调用counter->count()的时候,计数器已经是一个野指针了,程序崩溃。 这时候我们可以用命名计数器来解决这个问题。通过给计数器命名,并通过名字来计数,例如以下实现: ~~~cpp void counter_callback(WFCounterTask *counter) { WFRedisTask *next = WFTaskFactory::create_redis_task(...); series_of(counter)->push_back(next); } int main(void) { WFHttpTask *tasks[4]; WFCounterTask *counter; counter = WFTaskFactory::create_counter_task("c1", 3, counter_callback); counter->start(); for (int i = 0; i < 4; i++) { tasks[i] = WFTaskFactory::create_http_task(..., [](WFHttpTask *task){ WFTaskFactory::count_by_name("c1"); }); tasks[i]->start(); } ... } ~~~ 这个示例中,调起4个并发的http任务,其中3个完成了,立刻启动一个redis任务。实际应用中,可能还需要加入数据传递的代码。 示例中创建命名为"c1"的计数器,在http回调里,使用WFTaskFactory::count_by_name()调用来进行计数。 ~~~cpp class WFTaskFactory { ... static int count_by_name(const std::string& counter_name); static int count_by_name(const std::string& counter_name, unsigned int n); ... 
}; ~~~ WFTaskFactory::count_by_name方法还可以传入一个整数n,表示这一次操作要增加的计数值,显然: count_by_name("c1")等价于count_by_name("c1", 1)。 如果"c1"计数器不存在(未创建或已经完成),那么对"c1"的操作不产生任何效果,因此不会有匿名计数器野指针的问题。 函数的返回值表示被唤醒的计数器个数。当n大于1时,count_by_name操作可能让多个计数器达到目标值。 # 命名计数器详细行为定义 调用WFTaskFactory::count_by_name(name, n)的时候: * 如果name不存在(未创建或已经完成),无任何行为。 * 如果只有一个名字为name的计数器: * 如果该计数器剩余的值小于或等于n,计数完成,callback被调用,该计数器被销毁。结束。 * 如果计数器剩余值大于n,则计数值加n。结束。 * 如果存在多个同名为name的计数器: * 按照创建顺序,取第一个计数器,假设其剩余值为m: * 如果m值大于n,则计数加n。结束(剩余值为m-n)。 * 如果m小于或等于n,计数完成,callback被调用,第一个计数器被销毁。置n=n-m。 * 如果n为0,结束。 * 如果n大于0,再取出下一个同名计数器,重复整个的操作。 虽然描述很复杂,但总结起来就一句话,按照创建顺序,依次访问所有名字为name的计数器,直到n为0。 也就是说,一次count_by_name(name, n)可以唤醒多个计数器。 用好计数器,可以实现非常复杂的业务逻辑。计数器在我们框架里,往往用于实现异步锁,或者用于任务之间的通道。形态上更像一种控制任务。 workflow-0.11.8/docs/about-dns.md000066400000000000000000000137201476003635400166340ustar00rootroot00000000000000# 关于DNS 当使用域名请求网络时,首先需要通过域名解析获取服务器地址,再使用网络地址进行后续的请求。Workflow已经实现了完备的域名解析和缓存系统,通常来说用户无需知晓内部机制即可流畅地发起网络任务。 ## DNS相关配置 Workflow中的全局配置包括 ~~~cpp struct WFGlobalSettings { struct EndpointParams endpoint_params; struct EndpointParams dns_server_params; unsigned int dns_ttl_default; unsigned int dns_ttl_min; int dns_threads; int poller_threads; int handler_threads; int compute_threads; int fio_max_events; const char *resolv_conf_path; const char *hosts_path; }; ~~~ 其中与域名解析相关的配置项有 * dns_server_params * address_family: 该项会在后续展开说明 * max_connections: 向DNS服务器发送请求的最大并发数,默认为200 * connect_timeout/response_timeout/ssl_connect_timeout: 参考[超时](about-timeout.md)相关说明 * dns_threads: 当使用同步方式实现域名解析时,解析操作会在独立的线程池中执行,该项指定线程池的线程数,默认为4 * dns_ttl_default: 域名解析成功的结果会被放到域名缓存中,该项指定其存活时间,单位为秒,默认值1小时,当解析结果过期后会重新解析以获取最新内容 * dns_ttl_min: 当通信失败时,有可能出现缓存的结果已经失效的情况,该项指定一个较短的存活时间,当通信失败时以更频繁的速率更新缓存,单位为秒,默认值1分钟 * resolv_conf_path: 该文件保存了访问DNS相关的配置,在常见的Linux发行版上通常位于`/etc/resolv.conf`,若该项配置为`NULL`则表示使用多线程同步解析的模式 * hosts_path: 该文件是一个本地的域名查找表,若被解析的域名命中该表则不会向DNS发起请求,在常见的Linux发行版上通常位于`/etc/hosts`,若该项配置为`NULL`则表示不使用查找表 ### resolv.conf扩展功能 Workflow对`resolv.conf`配置文件进行了扩展,用户可以通过修改配置以支持`DNS 
over TLS(DoT)`功能,**注意**直接修改`/etc/resolv.conf`会影响其他进程,可以将该文件复制一份用于修改,并将Workflow的`resolv_conf_path`配置修改为新文件的路径。例如使用`dnss`协议的`nameserver`会通过SSL进行连接 ~~~bash nameserver dnss://8.8.8.8/ nameserver dnss://[2001:4860:4860::8888]/ ~~~ ### Address Family 在某些网络环境下,虽然本机支持IPv6,但因未被分配公网IPv6地址而无法与外部通信(例如本地IPv6地址以`fe80`开始)。此时可以将`endpoint_params.address_family`设置为`AF_INET`来强制域名解析时仅解析IPv4地址。同样的,`resolv.conf`文件中可能同时指定了`nameserver`的IPv4地址和IPv6地址,此时可以将`dns_server_params.address_family`设置为`AF_INET`或`AF_INET6`来强制仅使用IPv4或IPv6地址来访问DNS。 ### 使用Upstream配置 全局配置默认对每个域名生效,若需要对某些域名单独指定不同的配置,则可使用[Upstream](./about-upstream.md#Address属性)功能。使用Upstream可以单独指定`dns_ttl_default`、`dns_ttl_min`配置项,以及通过`endpoint_params.address_family`单独指定该域名使用的IP地址类别。 ## 域名解析与缓存策略 网络任务通常需要通过域名解析获取到需要访问的IP地址,Workflow中域名解析相关策略如下 1. 检查域名缓存是否有该域名对应的IP地址,若有缓存且未过期,则使用该组IP地址 2. 检查域名是否为IPv4、IPv6地址或`Unix Domain Socket`,若是则直接使用该地址,无需发起域名解析 3. 检查`hosts_path`文件中是否包含该域名对应的IP地址,若有则直接使用该地址 4. 获取异步锁,保证同一域名的解析请求在同一时刻仅发起一次,并向DNS发起解析请求 5. 解析成功后会将解析结果保存到当前进程的域名缓存中,以供下次使用,并释放异步锁 6. 
解析失败后会释放异步锁且将失败原因通知给等在同一个异步锁上的所有任务,通知结束后再发起的新的任务则会再次请求DNS 许多需要大量发起网络请求的场景都会配备域名缓存组件,如果每次发起网络任务时都向DNS发起解析请求,则DNS必然会不堪重负。Workflow设置了缓存存活时长(dns_ttl_default和dns_ttl_min)来保证缓存会在合理的时间后过期,以及时更新域名的解析结果。当某个域名的缓存项过期后,首先发现过期的任务会将其存活时间延长5秒并向DNS发起解析请求,5秒内同一域名上的请求会直接使用缓存的DNS解析结果,而无需等待本次解析结束。 异步锁机制可以保证**同一域名**的解析请求在同一时刻仅发起一次,在没有锁保护的情况下,若短时间内对同一域名发起大量网络任务,每个任务都会因无法从缓存中获取结果而向DNS发起解析请求,这会对DNS带来很大且不必要的负担。这里的同一域名表示的是`(host, port, family)`三元组,若通过Upstream的方式对某域名分别要求只使用IPv4和IPv6,则他们会被不同的异步锁保护,也就有可能同时发起DNS请求。 ### 异步域名解析 Workflow实现了完备的DNS任务(参考[dns_cli](./tutorial-17-dns_cli.md)),若指定了`resolv_conf_path`配置项,则向DNS发起域名解析时会使用异步请求的方式进行,在类Unix系统下,Workflow默认使用`/etc/resolv.conf`作为该配置的值。异步域名解析不会阻塞任何线程,也不会独占线程池,可以更高效地完成域名解析的任务。 ### 同步域名解析 若指定`resolv_conf_path`为`NULL`,则会通过调用`getaddrinfo`函数来实现同步域名解析,该方式会使用独立的线程池,其线程数通过`dns_threads`参数配置。若短时间内需要发起较多的域名解析请求,则同步的方式会带来较大的延迟。 workflow-0.11.8/docs/about-error.md000066400000000000000000000074461476003635400172110ustar00rootroot00000000000000# 关于错误处理 任何软件系统里,错误处理都是一个重要而复杂的问题。在我们框架内部,错误处理可以说是无处不在并且极其繁琐的。 而在我们暴露给用户的接口里,我们尽可能地让事情变简单,但用户还是不可避免地需要了解一些错误信息。 ### 禁用C++异常 我们框架内不使用C++异常,用户编译自己代码的时候,最好也加上-fno-exceptions标志,以减少代码大小。 参考业界通用做法,我们会忽略new操作失败的可能,并且内部也避免用new去分配大块内存。而C语言风格的内存分配则是有查错的。 ### 关于工厂函数 从之前的实例中我们看到,所有的task,series都是从WFTaskFactory或Workflow这两个工厂类产生的。 这些工厂类,以及我们以后可能遇到的更多的工厂类接口,都是确保成功的。也就是说,一定不会返回NULL。用户无需对返回值做检查。 为了达到这个目的,当URL不合法时,工厂也能正常产生task。并且在任务的callback里再得到错误。 ### 任务的状态和错误码 在之前的示例里,我们经常在callback里看到这样的代码: ~~~cpp void callback(WFXxxTask *task) { int state = task->get_state(); int error = task->get_error(); ... 
} ~~~ 其中,state代表任务的结束状态,在[WFTask.h](../src/factory/WFTask.h)文件中,可以看到所有可能的状态值: ~~~cpp enum { WFT_STATE_UNDEFINED = -1, WFT_STATE_SUCCESS = CS_STATE_SUCCESS, WFT_STATE_TOREPLY = CS_STATE_TOREPLY, /* for server task only */ WFT_STATE_NOREPLY = CS_STATE_TOREPLY + 1, /* for server task only */ WFT_STATE_SYS_ERROR = CS_STATE_ERROR, WFT_STATE_SSL_ERROR = 65, WFT_STATE_DNS_ERROR = 66, /* for client task only */ WFT_STATE_TASK_ERROR = 67, WFT_STATE_ABORTED = CS_STATE_STOPPED /* main process terminated */ }; ~~~ ##### 需要关注的几个状态: * SUCCESS:任务成功。client接收到完整的回复,或server把回复完全写进入发送缓冲(但不能确保对方一定能收到)。 * SYS_ERROR: 系统错误。这种情况,task->get_error()得到的是系统错误码errno。 * 当get_error()得到ETIMEDOUT,可以调用task->get_timeout_reason()进一步得到超时原因。 * DNS_ERROR: DNS解析错误。get_error()得到的是getaddrinfo()调用的返回码。关于DNS,有一篇文档专门说明[about-dns.md](./about-dns.md)。 * server任务永远不会有DNS_ERROR。 * SSL_ERROR: SSL错误。get_error()得到的是SSL_get_error()的返回值。 * 目前SSL错误信息没有做得很全,得不到ERR_get_error()的值。所以,基本上get_error()返回值也就三个可能: * SSL_ERROR_ZERO_RETURN, SSL_ERROR_X509_LOOKUP, SSL_ERROR_SSL。 * 更加详细的SSL错误信息,我们在后续版本会考虑加入。 * TASK_ERROR: 任务错误。常见的例如URL不合法,登录失败等。get_error()的返回值可以在[WFTaskError.h](../src/factory/WFTaskError.h)中查看。 ##### 用户一般无需关注的几个状态: * UNDEFINED: 刚创建完,还没有运行的client任务,状态是UNDEFINED。 * TOREPLY: server任务回复之前,没有被调用过task->noreply(),都是TOREPLY状态。 * NOREPLY: server任务被调用了task->noreply()之后,一直是NOREPLY状态。callback里也是这个状态。连接会被关闭。 ### 其它错误处理需求 除了任务本身的错误处理,各种具体协议的消息接口上,也会有判断错误的需要。一般这些接口都通过返回false来表示错误,并且通过errno传递错误原因。 此外,一些更复杂的用法,可能需要接触到更复杂一点的错误信息。我们在具体的文档里再做介绍。 workflow-0.11.8/docs/about-exit.md000066400000000000000000000127161476003635400170250ustar00rootroot00000000000000# 关于程序退出 由于我们的大多数调用都是非阻塞的,所以在之前的示例里我们都需要用一些机制来防止main函数提前退出。 例如wget示例中等待用户的Ctrl-C,或者像parallel_wget在所有抓取结束之后唤醒主线程。 而在几个server的示例中,stop()操作是阻塞的,可以确保所有server task的正常结束,主线程可安全退出。 # 程序安全退出的原则 一般情况下,用户只要正常写程序,模仿示例中的方法,不太会有什么关于退出的疑惑。但这里还是需要把程序正常退出的条件定义好。 * 用户不可以在callback或process等任何回调函数里调用系统的exit()函数,否则行为无定义。 * 主线程可以安全结束(main函数调用exit()或return)的条件是所有任务已经运行到callback,并且没有新的任务被调起。 * 
我们所有的示例都符合这个假设,在callback里唤醒main函数。这是安全的,不用担心main返回的时候,callback还没结束的情况。 * ParallelWork是一种task,也需要运行到callback。 * 这一条规则某些情况下可以违反,我们将在下一节解释。 * 所有server必须stop完成,否则行为无定义。因为stop操作用户都会调,所以一般server程序不会有什么退出方面的问题。 * server的stop会等待所有server任务所在series结束。但如果用户在process里直接start一个新任务,则需要考虑任务结束的问题。 # 为什么需要等待运行中的任务callback?能不能提前结束程序? 首先解释一下需要等待任务callback再结束程序的原因。在大多数情况下,我们通过任务工厂产生的任务,都是一个复合任务。 以http抓取任务为例,一个http任务可能需要先解析dns,再发起http抓取。如遇到302重定向,可能需要再次dns。任务失败可能还会重试。 也就是说,我们一个异步任务可能包含多个异步过程,但对用户完全无感。而内部每个异步过程之间,并不会检查程序是否已经退出。 如果用户明确知道一个任务是原子任务,例如以IP地址(或肯定能dns cache命中)创建http任务,并且无重定向或重试。 那么,这个任务可以被程序退出打断并提前来到callback,callback里任务的状态是WFT_STATE_ABORTED。 例如以下程序是绝对安全的: ~~~cpp void callback(WFHttpTask *task) { // 这里打印的结果大概率是2,WFT_STATE_ABORTED。 printf("state = %d\n", task->get_state()); } int main() { WFHttpTask *task = WFTaskFactory::create_http_task("https://127.0.0.1/", 0, 0, callback); task->start(); // 这里直接结束程序 return 1; } ~~~ 如果dns cache命中,也是安全的。因为内部无需再发起一个dns异步任务了。例如: ~~~cpp WFFacilities::WaitGroup wg(1); void callback_normal(WFHttpTask *task) { wg.done(); } void callback_abort(WFHttpTask *task) { // 这里打印的结果大概率是2,WFT_STATE_ABORTED。 printf("state = %d\n", task->get_state()); } int main() { WFHttpTask *task = WFTaskFactory::create_http_task("https://www.sogou.com/", 3, 2, callback_normal); task->start(); // 等待第一个访问www.sogou.com的任务结束。 wg.wait(); // 第二次访问www.sogou.com, dns信息已经被cache。 task = WFTaskFactory::create_http_task("https://www.sogou.com/", 0, 0, callback_abort); task->start(); // 这里直接结束程序 return 1; } ~~~ 所以,对于网络任务而言,只要能确定是一个原子任务,都可以被程序结束打断。这个原则可以扩展到任何类型的任务。 例如,定时器任务就是一个原子任务,以下程序也是绝对安全的: ~~~cpp void callback(WFTimerTask *task) { // 这里打印的结果肯定是2,WFT_STATE_ABORTED。 printf("state = %d\n", task->get_state()); } int main() { WFTimerTask *task = WFTaskFactory::create_timer_task(1000000, callback); task->start(); // 这里直接结束程序 return 1; } ~~~ 在[关于定时器](https://github.com/sogou/workflow/blob/master/docs/about-timer.md)的文档里,我们将会详细展开描述。 此外,单线程的计算任务,文件IO任务,也可以在callback之前直接结束程序。
其中,已经在执行计算的计算任务,程序会等待计算结束,最终以SUCCESS状态callback。还未被调起的,则以ABORTED状态退出。 文件IO任务,只要已经start,肯定会等待IO完成。因此直接退出程序完全安全。 # 关于OpenSSL 1.1版本在退出时的内存泄露 我们发现某些openssl1.1版本,存在退出时内存释放不完全的问题,通过valgrind内存检查工具可以看出内存泄露。 这个问题只有在用户使用了SSL,例如抓取了https网页时才会发生,而且一般情况下用户可以忽略这个泄露。 如果一定要解决,方法如下: ~~~cpp #include int main() { #if OPENSSL_VERSION_NUMBER >= 0x10100000L OPENSSL_init_ssl(0, NULL); #endif ... } ~~~ 也就是说在使用我们的库之前,先初始化openssl。如果你有需要也可以同时配置openssl的参数。 注意这个函数只在openssl1.1以上版本才有提供,所以调用之前需要先判断openssl版本。 这个内存泄露与openssl1.1的内存释放原理有关。我们提供的这个方案可以解决这个问题(但我们还是建议用户忽略)。 workflow-0.11.8/docs/about-go-task.md000066400000000000000000000144371476003635400174230ustar00rootroot00000000000000# 关于go task 我们提供了另一种更简单的使用计算任务的方法,模仿go语言实现的go task。 使用go task来实计算任务无需定义输入与输出,所有数据通过函数参数传递。 # 创建go task ~~~cpp class WFTaskFactory { ... public: template static WFGoTask *create_go_task(const std::string& queue_name, FUNC&& func, ARGS&&... args); }; ~~~ 函数参数的queue_name为计算队列名,其作用在之前示例文档中有过介绍。 func可以是函数指针,函数对象,仿函数,lambda函数,类的成员函数等任意可调用对象。 args为func的参数列表。注意当func是一个类的非静态成员函数时,args的第一个成员必须是对象地址。 # 示例 我们想异步的运行一个加法函数:void add(int a, int b, int& res); 并且我们还想在函数运行结束的时候打印出结果。于是可以这样实现: ~~~cpp #include #include #include "workflow/WFTaskFactory.h" #include "workflow/WFFacilities.h" void add(int a, int b, int& res) { res = a + b; } int main(void) { WFFacilities::WaitGroup wait_group(1); int a = 1; int b = 1; int res; WFGoTask *task = WFTaskFactory::create_go_task("test", add, a, b, std::ref(res)); task->set_callback([&](WFGoTask *task) { printf("%d + %d = %d\n", a, b, res); wait_group.done(); }); task->start(); wait_group.wait(); return 0; } ~~~ 以上的示例异步运行一个加法,打印结果并退出程序。go task的使用与其它的任务没有多少区别,也有user_data域可以使用。 唯一一点不同,是go task创建时不传callback,但和其它任务一样可以set_callback。 如果go task函数的某个参数是引用,需要使用std::ref,否则会变成值传递,这是c++11的特征。 # 把workflow当成线程池 用户可以只使用go task,这样可以将workflow退化成一个线程池,而且线程数量默认等于机器cpu数。 但是这个线程池比一般的线程池又有更多的功能,比如每个任务有queue name,任务之间还可以组成各种串并联或更复杂的依赖关系。 # 带执行时间限制的go task 通过create_timedgo_task接口(这里无法重载create_go_task接口),可以创建带时间限制的go task: 
~~~cpp class WFTaskFactory { /* Create 'Go' task with running time limit in seconds plus nanoseconds. * If time exceeded, state WFT_STATE_SYS_ERROR and error ETIMEDOUT will be got in callback. */ template static WFGoTask *create_timedgo_task(time_t seconds, long nanoseconds, const std::string& queue_name, FUNC&& func, ARGS&&... args); }; ~~~ 相比创建普通的go task,create_timedgo_task函数需要多传两个参数,seconds和nanoseconds。 如果func的运行时间到达seconds+nanosconds时限,task直接callback,且state为WFT_STATE_SYS_ERROR,error为ETIMEDOUT。 注意,框架无法中断用户执行中的任务。func依然会继续执行到结束,但不会再次callback。另外,nanoseconds取值区间在\[0,10亿)。 另外,当我们给go task加上了运行时间限制,callback的时机可能会先于func函数的结束,任务所在series可能也会先于func结束。 如果我们在func里访问series,可能就是一个错误了。例如: ~~~cpp void f(SeriesWork *series) { series->set_context(...); // 错误。当f是一个带超时的go task,此时series可能已经失效了。 } int http_callback(WFHttpTask *task) { SeriesWork *series = series_of(task); WFGoTask *go = WFTaskFactory::create_timedgo_task(1, 0, "test", f, series); // 1秒超时的go task series_of(task)->push_back(go); } ~~~ 这也是为什么,我们不推荐在计算任务的执行函数里,对任务所在的series进行操作。对series的操作,应该在callback里进行,例如: ~~~cpp int main() { WFGoTask *task = WFTaskFactory::create_timedgo_task(1, 0, "test", f); task->set_callback([](WFGoTask *task) { SeriesWork *series = series_of(task): void *context = series->get_context(); if (task->get_state() == WFT_STATE_SUCCESS) // 成功执行完 { ... } else // state == WFT_STATE_SYS_ERROR && error == ETIMEDOUT // 超过运行时间限制 { ... } }); } ~~~ 但是,在计算函数里使用task,是安全的。所以,可以使用task->user_data,在计算函数和callback之间传递数据。例如: ~~~cpp int main() { WFGoTask *task = WFTaskFactory::create_timedgo_task(1, 0, "test", [&task]() { task->user_data = (void *)123; }); task->set_callback([](WFGoTask *task) { SeriesWork *series = series_of(task): void *context = series->get_context(); if (task->get_state() == WFT_STATE_SUCCESS) // 成功执行完 { int result = (int)task->user_data; } else // state == WFT_STATE_SYS_ERROR && error == ETIMEDOUT // 超过运行时间限制 { ... } }); task->start(); ... 
} ~~~~ # 重置go task的执行函数 在某些时候,我们想在go task的执行函数里访问task,如上面的例子,将计算结果写入task的user_data域。 上例中,我们使用了引用捕获。但明显引用捕获会有一些问题。比如task本身的生命周期。我们更希望在执行函数里直接捕获go task指针。 直接进行值捕获明显是错误的,例如: ~~~cpp WFGoTask *task = WFTaskFactory::create_timedgo_task(1, 0, "test", [task]() { task->user_data = (void *)123; }); ~~~ 这段代码并不能在lambda函数里得到task指针,因为捕获执行时,task还没有赋值。但我们可以通过以下的代码,实现这个需求: ~~~cpp WFGoTask *task = WFTaskFactory::create_timedgo_task(1, 0, "test", nullptr); // 执行函数可以初始化为nullptr WFTaskFactory::reset_go_task(task, [task]() { task->user_data = (void *)123; }); ~~~ WFTaskFactory::reset_go_task()函数,用于重置go task的执行函数。 因为task已经创建完毕,这时候在lambda函数里捕获task,就是一个正确的行为了。 workflow-0.11.8/docs/about-module.md000066400000000000000000000071361476003635400173410ustar00rootroot00000000000000# 关于模块任务 我们的任务流是以task为元素。但很多情况下,用户需要模块级的封装,比如几个task完成一个特定的功能。 用原有的方法,就不得不让最后一个task的callback来衔接下一个任务,或者填写server任务的resp。这样不太合理。 因此,我们引入了WFModuleTask,方便用户封装模块,降低不同功能模块之间task的耦合。 # 模块任务的创建 我们把模块定义成一种特殊的任务,WFModuleTask。模块的内部包括一个sub_series用于运行模块内的任务。 对任务来讲,它无需关心自己是否运行在模块内。因为模块内的sub_series和普通series没有任何区别。 在[WFTaskFactory.h](/src/factory/WFTaskFactory.h)里,包括了模块任务的创建接口: ~~~cpp using module_callback_t = std::function<void (const WFModuleTask *)>; class WFTaskFactory { static WFModuleTask *create_module_task(SubTask *first, module_callback_t callback); }; ~~~ create_module_task()的第一个参数first代表模块首任务,这与创建series类似。 module callback参数要求是const指针。这主要是防止用户在callback里,继续向module中添加任务。 # WFModuleTask的主要接口 因为我们把模块也定义成一种任务,所以,可以像使用其它任务一样使用模块。但模块没有state和error域。 在[WFTask.h](/src/factory/WFTask.h)里,定义了WFModuleTask类。 ~~~cpp class WFModuleTask : public ParallelTask, protected SeriesWork // 不必关注这个派生关系 { public: void start() { ... } void dismiss() { ...
} public: SeriesWork *sub_series() { return this; } const SeriesWork *sub_series() const { return this; } public: void *user_data; }; ~~~ module特有的sub_series接口返回module内任务运行的series。module本质上是一个子任务流。 sub_series也是一个普通的series,用户可以调用它的set_context(),get_context(),push_back()等函数。 但我们不太建议给sub_series设置callback,因为没有什么必要,使用module的callback就可以了。 注意,在module的callback参数表,是const WFModuleTask \*,也就只能得到一个const的sub_series。 因此,在模块任务的callback里,只能调用sub_series的get_context()得到series上下文。 # 示例 在一个http server的处理逻辑中,我们把所有处理逻辑设计成一个模块。 ~~~cpp struct ModuleCtx { std::string body; }; void http_callback(WFHttpTask *http_task) { SeriesWork *series = series_of(http_task); // 这个series就是module的sub_series。 struct ModuleCtx *ctx = (struct ModuleCtx *)series->get_context(); const void *body; size_t size; if (http_task->get_resp()->get_parsed_body(&body, &size)) { ctx->body.assign((const char *)body, size); } ParallelWork *pwork = Workflow::create_parallel_work(…);// 做一些别的操作 series->push_back(pwork); } void process(WFHttpTask *server_task) { WFHttpTask *http_task = WFTaskFactory::create_http_task(…, http_callback); WFModuleTask *module = WFTaskFactory::create_module_task(http_task, [server_task](const WFModuleTask *mod) { struct ModuleCtx *ctx = (struct ModuleCtx *)mod->sub_series()->get_context(); server_task->get_resp()->append_output_body(ctx->body); delete ctx; }); module->sub_series()->set_context(new ModuleCtx); series_of(server_task)->push_back(module); } ~~~ 通过这个方法,module里的任务只需操作series context,最终由module的callback汇总填写resp。任务耦合性大幅降低。 workflow-0.11.8/docs/about-resource-pool.md000066400000000000000000000135311476003635400206460ustar00rootroot00000000000000# 资源池 在我们用workflow写异步程序时经常会遇到这样一些场景: * 任务运行时需要先从某个池子里获得一个资源。任务运行结束,则会把资源放回池子,让下一个需要资源的任务运行。 * 网络通信时需要对某一个或一些通信目标做总的并发度限制,但又不希望占用线程等待。 * 我们有许多随机到达的任务,处在不同的series里。但这些任务必须**串行**的运行。 所有这些需求,都可以用资源池模块来解决。我们的[WFDnsResolver](https://github.com/sogou/workflow/blob/master/src/nameservice/WFDnsResolver.cc)就是通过这个方法来实现对dns server的并发度控制的。 # 资源池的接口
在[WFResourcePool.h](https://github.com/sogou/workflow/blob/master/src/factory/WFResourcePool.h)里,定义了资源池模块的接口: ~~~cpp class WFResourcePool { public: WFConditional *get(SubTask *task, void **resbuf); WFConditional *get(SubTask *task); void post(void *res); ... protected: virtual void *pop() { return this->data.res[this->data.index++]; } virtual void push(void *res) { this->data.res[--this->data.index] = res; } ... public: WFResourcePool(void *const *res, size_t n); WFResourcePool(size_t n); ... }; ~~~ #### 构造函数 第一个构造函数接受一个资源数组,长度为n。数组每个元素为一个void \*。内部会再分配一份相同大小的内存,把数组复制走。 如果你的初始资源都是nullptr,那么你可以使用第二个构造函数,只需要传n,而无需先建立一个全部为nullptr的指针数组。 大概看看内部实现就明白了: ~~~cpp void WFResourcePool::create(size_t n) { this->data.res = new void *[n]; this->data.value = n; ... } WFResourcePool::WFResourcePool(void *const *res, size_t n) { this->create(n); memcpy(this->data.res, res, n * sizeof (void *)); } WFResourcePool::WFResourcePool(size_t n) { this->create(n); memset(this->data.res, 0, n * sizeof (void *)); } ~~~ #### 使用接口 用户使用get()接口,把任务打包成一个conditional。conditional是一个条件任务,条件满足时运行其包装的任务。 get()接口可包含第二个参数是一个void \*\*resbuf,用于保存所获得的资源。 接下来,用户只需要用这个conditional取代原来的任务使用就好了,可以start或串进任务流。 注意conditional是在它被执行时去尝试获得资源的,而不是在它被创建的时候。要不然的话,以下代码就会被卡死: ~~~cpp WFResourcePool pool(1); int f() { WFHttpTask *t1 = WFTaskFactory::create_http_task(..., [](void *){pool.post(nullptr);}); WFHttpTask *t2 = WFTaskFactory::create_http_task(..., [](void *){pool.post(nullptr);}); WFConditional *c1 = pool.get(t1, &t1->user_data); // 用user_data来保存res是一种实用方法。 WFConditional *c2 = pool.get(t2, &t2->user_data); c2->start(); // wait for t2 finish here. ... c1->start(); ... 
} ~~~ 以上代码c1先创建,等待t2结束后才运行。这里并不会出现c2卡死,因为conditional是在执行时才获得资源的。 当用户对资源使用完毕(一般在任务callback里),需要通过post()接口把资源放回池子。 post()时的res参数,**无需**与get()得到的res一致。 #### 派生 从上面的pop()和push()函数我们可以看到,我们对资源的使用默认是FILO,即先进后出的。 使用FILO的原因是,大多数场景下,刚刚被释放的资源应该优先被复用。 但是,用户可以通过派生的方式,非常简单的实现一个FIFO资源池。只需要重写pop()和push()两个virtual函数即可。 如果需要,你还可以实现可动态扩展和收缩的资源池。 # 示例 我们准备抓取一份URL列表,但要求总的并发度不超过max_p。我们当然可以用parallel来实现,但使用资源池可以更简单: ~~~cpp int fetch_with_max(std::vector<std::string>& url_list, size_t max_p) { WFResourcePool pool(max_p); for (std::string& url : url_list) { WFHttpTask *task = WFTaskFactory::create_http_task(url, [&pool](WFHttpTask *task) { pool.post(nullptr); }); WFConditional *cond = pool.get(task); // 无需保存res,可以不传resbuf参数。 cond->start(); } // wait_here... } ~~~ # 消息队列 消息队列是一种与资源池使用方法类似的组件。它们的区别在于: * 资源池的总资源数量是固定的,在创建时就已经确定。而消息队列的长度则不受限制。 * 资源池的存取方式是先进后出,刚刚释放的资源会先被复用。而消息队列则是先进先出。 * 资源池使用方式是先获取,后归还。没有获取就直接归还资源,可能导致缓冲区溢出。消息队列没有这样的约束。 * 实现上,资源池使用的是数组,消息队列使用链表。总体来讲,在实现和使用上,消息队列都比资源池简单一些。 # 消息队列接口 在[WFMessageQueue.h](https://github.com/sogou/workflow/blob/master/src/factory/WFMessageQueue.h)里,定义了消息队列模块的接口: ~~~cpp class WFMessageQueue { public: WFConditional *get(SubTask *task, void **msgbuf); WFConditional *get(SubTask *task); void post(void *msg); ... public: WFMessageQueue(); ...
}; ~~~ 由于了解过资源池的用法,消息队列的使用方式我们也就无需再详细展开。模式和资源池一样,都是在获得消息(或资源)时,任务被拉起。 消息队列的get和post接口,无需像资源池一样遵循先获取再放回的原则,任何任务都可以随时从队列中存取消息。 如果有需要,用户同样可以派生WFMessageQueue类,实现先进后出的消息读取模式。 workflow-0.11.8/docs/about-selector.md000066400000000000000000000077761476003635400177060ustar00rootroot00000000000000# 关于Selector任务 我们业务中经常有一些需求,从几个异步分支中选择第一个成功完成的结果进行处理,丢弃其它结果。 Selector任务就是为了上述这种多选一场景而设计的。 # Selector解决的问题 常见的多选一场景例如: * 向多个下游发送网络请求,只要任意一个下游返回正确结果,工作流程就可以继续。 * 执行一组复杂的操作,操作执行完成或整体超时,流程都会继续。 * 并行计算中,任何一个线程计算出预期的结果即完成,例如MD5碰撞计算。 * 网络应用中的‘backup request’,也可以用selector配合timer来实现。 在selector任务被引入之前,这些场景很难被很好解决,涉及到任务生命周期以及丢弃结果的资源回收等问题。 # 创建Selector任务 Selector也是一种任务,所以一般由WFTaskFactory里的工厂函数产生: ~~~cpp using selector_callback_t = std::function; class WFTaskFactory { public: static WFSelectorTask *create_selector_task(size_t candidates, selector_callback_t callback); }; ~~~ 其中,candidates参数代表从多少个候选路径中选择。Selector任务创建后,必须有candidates次被提交才会被销毁。 因此,用户可以放心的(也是必须的)向selector提交candidates次,无需要担心selector的生命周期问题。 # Selector类的接口 WFSelectorTask类包括两个主要接口。其中,对提交者来讲,只需要关注submit函数。对于等待者,只需使用到get_message。 ~~~cpp class WFSelectorTask : public WFGenericTask { public: virtual int submit(void *msg); void *get_message() const; }; ~~~ 当第一个非空指针的msg被提交,submit函数返回1表示接受。随后的submit调用都返回0代表消息被拒绝。 Selector运行后接收到一个有效消息就进入callback了,但在收到所有submit之前,不会被销毁。 注意空指针永远不会被接受,所以submit一个NULL将返回0。一般来讲,submit(NULL)用于表示这个分支失败了。 如果所有候选都提交了NULL,selector运行到callback时,state=WFT_STATE_SYS_ERROR, error=ENOMSG。 作为等待者,在selector的callback里调用另外一个接口get_message()就可以得到被成功接受的消息了。 # 示例 我们同时抓取两个http网页,并设置一个超时。当任意一个先抓取成功或超时,打印出抓取成功的URL或出错信息。 示例中使用wait group来保证两个抓取任务已经结束才退出程序。而timer可以被程序退出打断,无需等待。 ~~~cpp #include #include #include "workflow/WFTaskFactory.h" #include "workflow/WFFacilities.h" WFSelectorTask *selector; WFFacilities::WaitGroup wait_group(2); void http_callback(WFHttpTask *t) { if (t->get_state() == WFT_STATE_SUCCESS) selector->submit(t->user_data); else selector->submit(NULL); wait_group.done(); } int main(int argc, char *argv[]) { if (argc != 4) { 
fprintf(stderr, "USAGE: %s \n", argv[0]); exit(1); } selector = WFTaskFactory::create_selector_task(3, [](WFSelectorTask *selector) { void *msg = selector->get_message(); if (msg) printf("%s\n", (char *)msg); else printf("failed\n"); }); auto *t = WFTaskFactory::create_http_task(argv[1], 0, 0, http_callback); t->user_data = argv[1]; t->start(); t = WFTaskFactory::create_http_task(argv[2], 0, 0, http_callback); t->user_data = argv[2]; t->start(); auto *timer = WFTaskFactory::create_timer_task(atoi(argv[3]), 0, [](WFTimerTask *timer){ if (timer->get_state() == WFT_STATE_SUCCESS) selector->submit((void *)"timeout"); else selector->submit(NULL); }); timer->start(); selector->start(); wait_group.wait(); return 0; } ~~~ workflow-0.11.8/docs/about-service-governance.md000066400000000000000000000224011476003635400216310ustar00rootroot00000000000000# 关于服务治理 我们拥有一套完整的机制,来管理我们所依赖的服务。这套机制包括以下的几个功能: * 用户级DNS。 * 服务地址的选取 * 包括多种选取机制,如权重随机,一致性哈希,用户指定选取方式等。 * 服务的熔断与恢复。 * 负载均衡。 * 单个服务的独立参数配置。 * 服务的主备关系等。 所有这些功能都依赖于我们的upstream子系统。利用好这个系统,我们可以轻易地实现更复杂的服务网格功能。 # upstream名 upstream名相当于程序内部的域名,但相比一般的域名,upstream拥有更多的功能,包括: * 域名通常只能指向一组ip地址,upstream名可以指向一组ip地址或域名。 * upstream指向的对象(域名或ip),可以包括端口信息。 * upstream有管理和选择目标的强大功能,每个目标可以包含大量属性。 * upstream的更新,是实时而且完全线程安全的,而域名的DNS信息,并不能实时更新。 实现上,如果无需访问外网,用upstream可以完全代替域名和DNS。 # upstream的创建与删除 在[UpstreamManager.h](../src/manager/UpstreamManager.h)里,包括几个upstream创建接口: ~~~cpp using upstream_route_t = std::function; class UpstreamManager { public: static int upstream_create_consistent_hash(const std::string& name, upstream_route_t consitent_hash); static int upstream_create_weighted_random(const std::string& name, bool try_another); static int upstream_create_manual(const std::string& name, upstream_route_t select, bool try_another, upstream_route_t consitent_hash); static int upstream_delete(const std::string& name); ... 
}; ~~~ 三个函数分别创建3种类型的upstream:一致性hash,权重随机和用户手动选取。 参数name为upstream名,创建之后,就可以和域名一样使用了。 consistent_hash和select参数,都是一个类型为upstream_route_t的std::function,用于指定路由方式。 而try_another表示,如果选取到的目标不可用(熔断),是否继续尝试找到一个可用目标。consistent_hash模式没有这个属性。 upstream_route_t参数接收的3个参数分别是url里的path, query和fragment部分。例如URL为:http://abc.com/home/index.html?a=1#bottom 则这三个参数分别为"/home/index.html", "a=1"和"bottom"。用户可以根据这三个部分,选择目标服务器,或者进行一致性hash。 注意,以上接口中,consistent_hash参数都可以传nullptr,我们将使用默认的一致性哈希算法。 # 示例1:权重分配 我们想把50%访问www.sogou.com的请求,打到127.0.0.1:8000和127.0.0.1:8080两个地址,并且让他们的负载为1:4。 我们无需关心域名www.sogou.com之下,有多少个ip地址。总之实际域名会接收50%的请求。 ~~~cpp #include "workflow/UpstreamManager.h" #include "workflow/WFTaskFactory.h" int main() { UpstreamManager::upstream_create_weighted_random("www.sogou.com", false); struct AddressParams params = ADDRESS_PARAMS_DEFAULT; params.weight = 5; UpstreamManager::upstream_add_server("www.sogou.com", "www.sogou.com", &params); params.weight = 1; UpstreamManager::upstream_add_server("www.sogou.com", "127.0.0.1:8000", &params); params.weight = 4; UpstreamManager::upstream_add_server("www.sogou.com", "127.0.0.1:8080", &params); WFHttpTask *task = WFTaskFactory::create_http_task("http://www.sogou.com/index.html", ...); ...
} ~~~ 请注意,以上这些函数可以在任何场景下调用,完全线程安全,并实时生效。 另外,由于我们一切协议,包括用户自定义协议都有URL,所以upstream功能可作用于一切协议。 # 示例2:手动选择 同样是上面的例子,我们想让url里,query为"123"的请求,打到127.0.0.1:8000,如果是"abc",打到8080端口,其它打正常域名。 ~~~cpp #include "workflow/UpstreamManager.h" #include "workflow/WFTaskFactory.h" int my_select(const char *path, const char *query, const char *fragment) { if (strcmp(query, "123") == 0) return 1; else if (strcmp(query, "abc") == 0) return 2; else return 0; } int main() { UpstreamManager::upstream_create_manual("www.sogou.com", my_select, false, nullptr); UpstreamManager::upstream_add_server("www.sogou.com", "www.sogou.com"); UpstreamManager::upstream_add_server("www.sogou.com", "127.0.0.1:8000"); UpstreamManager::upstream_add_server("www.sogou.com", "127.0.0.1:8080"); /* This URL will route to 127.0.0.1:8080 */ WFHttpTask *task = WFTaskFactory::create_http_task("http://www.sogou.com/index.html?abc", ...); ... } ~~~ 由于我们原生提供了redis和mysql协议,用这个方法,可以极其方便的实现数据库的读写分离功能(注:非事务的操作)。 以上两个例子,upstream名用的是www.sogou.com,这本身也是一个域名。当然用户可以更简单的用字符串sogou,这样创建任务时: ~~~cpp WFHttpTask *task = WFTaskFactory::create_http_task("http://sogou/home/1.html?abc", ...); ~~~ 总之url的host部分,如果是一个已经创建的upstream,则会被当作upstream使用。 # 示例3:一致性hash 这个场景里,我们要从10个redis实例中,随机选择一台机器通信。但保证同一个url肯定访问一个确定的目标。方法很简单: ~~~cpp int main() { UpstreamManager::upstream_create_consistent_hash("redis.name", nullptr); UpstreamManager::upstream_add_server("redis.name", "10.135.35.53"); UpstreamManager::upstream_add_server("redis.name", "10.135.35.54"); UpstreamManager::upstream_add_server("redis.name", "10.135.35.55"); ... UpstreamManager::upstream_add_server("redis.name", "10.135.35.62"); auto *task = WFTaskFactory::create_redis_task("redis://:mypassword@redis.name/2?a=hello#111", ...); ... 
} ~~~ 我们的redis任务并不识别query部分,用户可以随意填写。path部分的2为redis库号。 这个时候,consistent_hash函数将得到"/2","a=hello"和"111"三个参数,但因为我们用nullptr,默认一致性hash将被调用。 upstream里的服务器没有指定端口号,于是将使用url里的端口。redis默认为6379。 consitent_hash并没有try_another选项,如果目标熔断,将自动选取另一个。相同url还将得到相同选择(cache友好)。 # upstream server的参数 示例1中,我们通过params参数设置了server的权重。当然server参数远不止权重一项。这个结构定义如下: ~~~cpp // In EndpointParams.h struct EndpointParams { size_t max_connections; int connect_timeout; int response_timeout; int ssl_connect_timeout; bool use_tls_sni; }; // In ServiceGovernance.h struct AddressParams { struct EndpointParams endpoint_params; ///< Connection config unsigned int dns_ttl_default; ///< in seconds, DNS TTL when network request success unsigned int dns_ttl_min; ///< in seconds, DNS TTL when network request fail /** * - The max_fails directive sets the number of consecutive unsuccessful attempts to communicate with the server. * - After 30s following the server failure, upstream probe the server with some live client’s requests. * - If the probes have been successful, the server is marked as a live one. * - If max_fails is set to 1, it means server would out of upstream selection in 30 seconds when failed only once */ unsigned int max_fails; ///< [1, INT32_MAX] max_fails = 0 means max_fails = 1 unsigned short weight; ///< [1, 65535] weight = 0 means weight = 1. only for main server int server_type; ///< 0 for main and 1 for backup int group_id; ///< -1 means no group. 
Backup without group will backup for any main node }; ~~~ 大多数参数的作用一目了然。其中endpoint_params和dns相关参数,可以覆盖全局的配置。 例如,全局对每个目标ip最大连接数为200,但我想为10.135.35.53设置最多1000连接数,可以这么做: ~~~cpp UpstreamManager::upstream_create_weighted_random("10.135.35.53", false); struct AddressParams params = ADDRESS_PARAMS_DEFAULT; params.endpoint_params.max_connections = 1000; UpstreamManager::upstream_add_server("10.135.35.53", "10.135.35.53", &params); ~~~ max_fails参数为最大出错次数,如果选取目标连续出错达到max_fails则熔断,如果upstream的try_another属性为false,则任务失败, 在任务callback里,get_state()=WFT_STATE_TASK_ERROR,get_error()=WFT_ERR_UPSTREAM_UNAVAILABLE。 如果try_another为true,并且所有server都熔断的话,会得到同样错误。熔断时间为30秒。 server_type和group_id用于主备功能。所有upstream必须有type为0(主节点)的server,否则upstream不可用。 类型为1(备份节点)的server,会在同group_id的主节点熔断情况下被使用。 更多upstream功能查询:[about-upstream.md](./about-upstream.md)。 workflow-0.11.8/docs/about-timeout.md000066400000000000000000000200011476003635400175240ustar00rootroot00000000000000# 关于超时 为了让所有通信任务可以在用户的预期下精确运行,我们提供了大量的超时配置功能,并且确保这些超时的准确性。 这些超时配置里,有些是全局的,比如连接超时,但你又可以通过upstream功能,给某个域名配置自己的连接超时。 有一些超时是任务级的,比如完整发送一条消息的超时。因为用户需要根据消息大小,动态配置这个值。 当然对server来讲,又有自己的超时整体配置。总之,超时是一件很复杂的事,我们会做得很精确。 所有超时都采用poll风格,也就是int型,毫秒级,-1表示无限。 另外,正如我们在项目介绍里说的,所有的配置你都可以忽略,可以等遇到实际需求了再进行调整。 ### 基础通信超时配置 在[EndpointParams.h](../src/manager/EndpointParams.h)文件里,可以看到: ~~~cpp struct EndpointParams { size_t max_connections; int connect_timeout; int response_timeout; int ssl_connect_timeout; }; static constexpr struct EndpointParams ENDPOINT_PARAMS_DEFAULT = { .max_connections = 200, .connect_timeout = 10 * 1000, .response_timeout = 10 * 1000, .ssl_connect_timeout = 10 * 1000, }; ~~~ 其中,与超时相关的配置包括以下3项。 * connect_timeout: 与目标建立连接的超时。默认为10秒。 * response_timeout: 等待目标响应的超时,默认为10秒。代表成功发送到目标、或从目标读取到一块数据的超时。 * ssl_connect_timeout: 与目标完成SSL握手的超时。默认为10秒。 这个结构体是通信连接的最基础的配置,后续几乎所有的通信配置都会含有这个结构体。 ### 全局超时配置 在[WFGlobal.h](../src/manager/WFGlobal.h)文件里,可以看到我们一个全局配置信息: ~~~cpp struct WFGlobalSettings { EndpointParams endpoint_params; unsigned int dns_ttl_default;
unsigned int dns_ttl_min; int dns_threads; int poller_threads; int handler_threads; int compute_threads; }; static constexpr struct WFGlobalSettings GLOBAL_SETTINGS_DEFAULT = { .endpoint_params = ENDPOINT_PARAMS_DEFAULT, .dns_ttl_default = 12 * 3600, /* in seconds */ .dns_ttl_min = 180, /* reacquire when communication error */ .dns_threads = 8, .poller_threads = 2, .handler_threads = 20, .compute_threads = -1 }; //compute_threads<=0 means auto-set by system cpu number ~~~ 其中,与超时相关的配置就是EndpointParams endpoint_params这一项 修改全局配置的方法是,调用我们任何工厂函数之前,执行类似下面的操作: ~~~cpp int main() { struct WFGlobalSettings settings = GLOBAL_SETTINGS_DEFAULT; settings.endpoint_params.connect_timeout = 2 * 1000; settings.endpoint_params.response_timeout = -1; WORKFLOW_library_init(&settings); } ~~~ 上例把连接超时修改为2秒,远程服务器响应超时为无限。这种配置下,每次任务里都必须配置接收完整消息的超时,否则可能陷入无限的等待。 全局的超时配置,可以通过upstream功能,被单独的地址配置覆盖,比如你可以指定某个域名的连接超时。 Upstream每一个AddressParams也有一个EndpointParams endpoint_params项,使用方式与Global相仿。 具体结构详见[upstream文档](tutorial-10-upstream.md#Address属性) ### Server超时配置 在[http_proxy](./tutorial-05-http_proxy.md)示例的里,我们介绍过server启动配置。其中超时相关的配置包括: * peer_response_timeout: 这个的定义和全局的response_timeout一样,指的是远程client的响应超时,默认为10秒。 * receive_timeout: 接收一条完整请求的超时,默认为-1。 * keep_alive_timeout: 连接保持时间。默认1分钟。redis server为5分钟。 * ssl_accept_timeout: 完成ssl握手的超时,默认为10秒。 在这个默认配置下,client可以每9秒发送一个字节,让server一直接收而不引起超时。所以,如果服务用于公网,需要配置receive_timeout。 ### 任务级别的超时配置 任务级别的超时配置通过网络任务的几个接口调用来完成: ~~~cpp template class WFNetworkTask : public CommRequest { ... public: /* All in milliseconds. timeout == -1 for unlimited. */ void set_send_timeout(int timeout) { this->send_timeo = timeout; } void set_receive_timeout(int timeout) { this->receive_timeo = timeout; } void set_keep_alive(int timeout) { this->keep_alive_timeo = timeout; } void set_watch_timeout(int timeout) { this->watch_timeo = timeout; } ... 
} ~~~ 其中,set_send_timeout()设置发送完整消息的超时,默认值为-1。 set_receive_timeout()只对client任务有效,指接收完整server回复的超时,默认值为-1。 * server任务的receive_timeout在server启动配置里。对server任务设置receive_timeout没有意义,因为消息已经接收完成。 set_keep_alive()接口设置连接保持超时。一般来讲,框架能很好的处理连接保持的问题,用户不需要调用。 如果是http协议,client或server想要使用短连接,可通过添加HTTP header来完成,尽量不要用这个接口去修改。 如果一个redis client想要在请求之后关闭连接,则需要用这个接口。显然,在callback里set_keep_alive()是无效的(连接已经被复用)。 set_watch_timeout()接口为client任务专有,代表一个client任务的请求发出之后,接收到第一个返回包的最大等待时间。 利用watch timeout,可以避免一些需要等待数据推送的client任务受到response timeout和receive timeout的约束而超时。 设置了watch timeout之后,从接收到第一个数据包再开始计算receive timeout。 ### 任务的同步等待超时 有一个非常特殊的超时配置,是全局唯一一个同步等待超时。我们并不鼓励使用,但在某些应用场景下能得到很好的效果。 目前框架里,目标服务器是有连接上限的(全局和upstream都可以配置)。如果连接已经达到上限,默认的情况下,client任务失败返回。 callback里task->get_state()得到WFT_STATE_SYS_ERROR, task->get_error()得到EAGAIN。如果任务配置了retry,会自动发起重试。 在这里,我们允许通过task->set_wait_timeout()接口,配置一个同步等待超时,如果在这段时间内,有连接被释放,则任务可以占用这个连接。 如果用户配置了wait_timeout,并且在超时之前没有拿到连接,则callback得到WFT_STATE_SYS_ERROR状态和ETIMEDOUT错误。 ~~~cpp class CommRequest : public SubTask, public CommSession { public: ... void set_wait_timeout(int wait_timeout) { this->wait_timeout = wait_timeout; } } ~~~ ### 超时的原因查看 通信task包含一个get_timeout_reason()接口,用于返回超时原因,但不是很细致,包括以下几个返回值: * TOR_NOT_TIMEOUT: 不是超时。 * TOR_WAIT_TIMEOUT: 同步等待超时。 * TOR_CONNECT_TIMEOUT: 连接超时。包括TCP,SCTP等协议的连接和SSL连接超时,都是这个值。 * TOR_TRANSMIT_TIMEOUT: 一切传输超时。不能进一步区分是发送阶段还是接收阶段。以后可能会细化。 * server任务,超时原因一定是TRANSMIT_TIMEOUT,并且一定是发送回复的阶段。 ### 超时功能的实现 框架内部,需要处理的超时种类比我们在这里展现的还要更多。除了wait_timeout,全都是依赖于Linux的timerfd或kqueue的timer事件。 每个poller线程包含一个timerfd,默认配置下,poller线程数为4,可以满足大多数应用的需要了。 目前的超时算法利用了链表+红黑树的数据结构,时间复杂度在O(1)和O(logn)之间,其中n为poller线程的fd数量。 超时处理目前看不是瓶颈所在,因为Linux内核epoll相关调用也是O(logn)时间复杂度,我们把超时都做到O(1)也区别不大。 workflow-0.11.8/docs/about-timer.md000066400000000000000000000110241476003635400171630ustar00rootroot00000000000000# 关于定时器 定时器的作用是不占线程的等待一个确定时间,同样通过callback来通知定时器到期。 # 定时器的创建 WFTaskFactory类里包括四个定时相关的接口: ~~~cpp using timer_callback_t = std::function; class WFTaskFactory { ... 
public: static WFTimerTask *create_timer_task(time_t seconds, long nanoseconds, timer_callback_t callback); static WFTimerTask *create_timer_task(const std::string& timer_name, time_t seconds, long nanoseconds, timer_callback_t callback); static int cancel_by_name(const std::string& timer_name) { return cancel_by_name(timer_name, (size_t)-1); } static int cancel_by_name(const std::string& timer_name, size_t max); }; ~~~ 我们通过seconds和nanoseconds两个参数来指定一个定时器的定时时间。其中,seconds指定秒数而nanoseconds为纳秒数。 * seconds参数可以传递-1,产生一个无限时长的定时器,一般用于命名定时器,为了将来调用cancel取消定时。 * nanoseconds的取值范围在[0,1000000000),否则timer运行之后会立刻错误返回,错误码为EINVAL。 在创建定时器任务时,可以传入一个timer_name作为定时器名,用于cancel_by_name接口取消定时。 定时器也是一种任务,因此使用方式与其它类型任务无异,同样有user_data域可以利用。 # 取消定时 如果在创建定时器任务时传入一个名称,那么这个定时器就可以被提前中断。 中断一个定时任务的方法是通过WFTaskFactory::cancel_by_name这个接口,这个接口默认情况下,会取消这个名称下的所有定时器。 因此,我们也支持传入一个max参数,让操作最多取消max个定时器。无论哪个接口,返回值都是代表实际被取消的定时器个数。 如果没有这个名称下的定时器,cancel操作不会产生任何效果,并返回0。 定时器在被创建之后就可取消,并非一定要等它被启动之后。以这个代码为例: ~~~cpp #include <stdio.h> #include "workflow/WFTaskFactory.h" int main() { WFTimerTask *timer = WFTaskFactory::create_timer_task("test", 10000, 0, [](WFTimerTask *task){ printf("timer callback, state = %d, error = %d.\n", task->get_state(), task->get_error()); }); WFTaskFactory::cancel_by_name("test"); timer->start(); getchar(); return 0; } ~~~ 程序会立即打印出"timer callback, state = 1, error = 125.",因为定时器在运行之前就已经被取消了。所以,定时任务启动后立即callback,状态码为WFT_STATE_SYS_ERROR,错误码为ECANCELED。 使用中需要注意的是,命名定时器比匿名定时器是会多出一些开销的,原因是我们需要维护查找表,会有加锁解锁等操作。如果你的定时器没有提前中断的需要,就不要在创建时传入timer_name了。 # 程序退出打断定时器 在[关于程序退出](./about-exit.md)里讲到,main函数结束或exit()被调用的时候,所有任务必须已运行到callback,并且没有新的任务被调起。 这时就可能出现一个问题,定时器的定时周期可以非常长,如果是不可中断的定时器,那么等待定时器到期,程序退出需要很长时间。 而实现上,程序退出是可以打断定时器,让定时器回到callback的。如果定时器被程序退出打断,get_state()会得到一个WFT_STATE_ABORTED状态。 当然如果定时器被程序退出打断,则不能再调起新的任务。 以下这个程序,每间隔一秒抓取一个http页面。当所有url抓取完毕,程序直接退出,不用等待timer回到callback,退出不会有延迟。 ~~~cpp bool program_terminate = false; void timer_callback(WFTimerTask *timer) { mutex.lock(); if (!program_terminate) {
WFHttpTask *task; if (urls_to_fetch > 0) { task = WFTaskFactory::create_http_task(...); series_of(timer)->push_back(task); } series_of(timer)->push_back(WFTaskFactory::create_timer_task(1, 0, timer_callback)); } mutex.unlock(); } ... int main() { .... /* all urls done */ mutex.lock(); program_terminate = true; mutex.unlock(); return 0; } ~~~ 以上程序,timer_callback必须在锁里判断program_terminate条件,否则可能在程序已经结束的情况下又调起新任务。 workflow-0.11.8/docs/about-tlv-message.md000066400000000000000000000111031476003635400202700ustar00rootroot00000000000000# 关于TLV(Type-Length-Value)格式的消息 TLV消息是一种由类型,长度,内容组成的消息。由于其结构简单通用,而且方便嵌套和扩展,特别适用于定义通信消息。 为方便用户实现自定义协议,我们内置了TLV消息的支持。 # TLV消息的结构 TLV消息并没有具体规定Type和Length这两个字段占的字节数据。在我们的协议里,它们分别占4字节(网络序)。 也就是说,我们的消息有8字节的消息头,以及不超过32GB的Value内容。Type和Value域的含义我们不做规定。 # TLVMessage类 由于TLV的定义内容很少,所以[TLVMessage](/src/protocol/TLVMessage.h)需要用到的接口很少。 ~~~cpp namespace protocol { class TLVMessage : public ProtocolMessage { public: int get_type() const { return this->type; } void set_type(int type) { this->type = type; } std::string *get_value() { return &this->value; } void set_value(std::string value) { this->value = std::move(value); } protected: int type; std::string value; ... 
}; using TLVRequest = TLVMessage; using TLVResposne = TLVMessage; } ~~~ 用户直接使用TLV消息来做数据传输的话,只需要用到上面的几个接口。分别为设置和获取Type与Value。 Value直接以std::string返回,方便用户必要的时候直接通过std::move移动数据。 # 基于TLV消息的echo server/client 以下代码,直接启动一个基于TLV消息的server,并通过命令行产生client task进行交互。建议运行一下: ~~~cpp #include #include #include #include "workflow/WFGlobal.h" #include "workflow/WFFacilities.h" #include "workflow/TLVMessage.h" #include "workflow/WFTaskFactory.h" #include "workflow/WFServer.h" using namespace protocol; using WFTLVServer = WFServer; using WFTLVTask = WFNetworkTask; using tlv_callback_t = std::function; WFTLVTask *create_tlv_task(const char *host, unsigned short port, tlv_callback_t callback) { auto *task = WFNetworkTaskFactory::create_client_task( TT_TCP, host, port, 0, std::move(callback)); task->set_keep_alive(60 * 1000); return task; } int main() { WFTLVServer server([](WFTLVTask *task) { *task->get_resp() = std::move(*task->get_req()); }); if (server.start(8888) != 0) { perror("server.start"); exit(1); } auto&& create = [](WFRepeaterTask *)->SubTask * { std::string string; printf("Input string (Ctrl-D to exit): "); std::cin >> string; if (string.empty()) return NULL; auto *task = create_tlv_task("127.0.0.1", 8888, [](WFTLVTask *task) { if (task->get_state() == WFT_STATE_SUCCESS) printf("Server Response: %s\n", task->get_resp()->get_value()->c_str()); else { const char *str = WFGlobal::get_error_string(task->get_state(), task->get_error()); fprintf(stderr, "Error: %s\n", str); } }); task->get_req()->set_value(std::move(string)); return task; }; WFFacilities::WaitGroup wait_group(1); WFRepeaterTask *repeater = WFTaskFactory::create_repeater_task(std::move(create), nullptr); Workflow::start_series_work(repeater, [&wait_group](const SeriesWork *) { wait_group.done(); }); wait_group.wait(); server.stop(); return 0; } ~~~ # 派生TLVMessage 上面的echo server实例,我们直接使用了原始的TLVMessage。但建议在具体的应用中,用户可以对消息进行派生。 在派生类里,提供更加丰富的接口来设置和提取消息内容,避免直接操作原始Value域,并形成自己的二级协议。 例如,我们实现一个JSON的协议,可以: ~~~cpp #include 
"workflow/json-parser.h" // 内置的json解析器 class JsonMessage : public TLVMessage { public: void set_json_value(const json_value_t *val) { this->type = JSON_TYPE; this->json_to_string(val, &this->value); // 需要实现一下 } json_value_t *get_json_value() const { if (this->type == JSON_TYPE) return json_parser_parse(this->value.c_str()); // json-parser的函数 else return NULL; } }; using JsonRequest = JsonMessage; using JsonResponse = JsonMessage; using JsonServer = WFServer; ~~~ 这个例子只是为了说明派生的重要性,实际应用中,派生类可能要远远比这个复杂。 workflow-0.11.8/docs/about-upstream.md000066400000000000000000000473331476003635400177170ustar00rootroot00000000000000# 关于Upstream 在nginx里,Upstream代表了反向代理的负载均衡配置。在这里,我们扩充Upstream的含义,让其具备以下几个特点: 1. 每一个Upstream都是一个独立的反向代理 2. 访问一个Upstream等价于,在一组服务/目标/上下游,使用合适的策略选择其中一个进行访问 3. Upstream具备负载均衡、出错处理、熔断和其他服务治理能力 4. 对于同一个请求的多次重试,Upstream可以避开已试过的目标 5. 通过Upstream可以对不同下游配置不同的连接参数 6. 动态增删目标地址实时生效,方便对接任意的服务发现系统 ### Upstream相对于域名DNS解析的优势 Upstream和域名DNS解析都可以将一组ip配置到一个Host,但是 1. DNS域名解析是不针对于端口号的,相同ip不同端口的服务DNS域名是不能配置到一起的;但Upstream可以 2. DNS域名解析对应的一组address,必定是ip;Upstream对应的一组address,可以是ip、域名或unix-domain-socket 3. 通常情况下,DNS域名解析会被操作系统或网络上DNS服务器所缓存,更新时间受到ttl的限制;Upstream可以做到实时更新实时生效 4. DNS域名解析消耗比Upstream解析和选取大很多 ### Workflow的Upstream 这是一个本地反向代理模块,代理配置对server和client都生效。 支持动态配置,可用于服务发现系统,目前[workflow-k8s](https://github.com/sogou/workflow-k8s)可以对接Kubernetes的API Server。 Upstream名不包括端口,但Upstream请求支持指定端口(如果使用非内置协议,Upstream名暂时需要加上端口号以保证构造时的解析成功)。 每一个Upstream配置自己的独立名称UpstreamName,并添加设定着一组Address,这些Address可以是: 1. ip4 2. ip6 2. 域名 3. unix-domain-socket ### 为什么要替代nginx的Upstream #### nginx的Upstream工作方式 1. 只支持http/https协议 2. 需要搭建一个nginx服务,启动进程占用socket等其他资源 3. 请求先打到nginx上,nginx再向远端转发请求,这会多一次通信开销 #### workflow本地Upstream工作方式 1. 协议无关,你甚至可以通过upstream访问mysql、redis、mongodb等等 2. 无需额外启动其他进程或端口,直接在进程内模拟反向代理的功能 3. 
选取过程是基本的计算和查表,不会有额外的通信开销 # 使用Upstream ### 常用接口 ~~~cpp class UpstreamManager { public: static int upstream_create_consistent_hash(const std::string& name, upstream_route_t consitent_hash); static int upstream_create_weighted_random(const std::string& name, bool try_another); static int upstream_create_manual(const std::string& name, upstream_route_t select, bool try_another, upstream_route_t consitent_hash); static int upstream_create_vnswrr(const std::string& name); static int upstream_delete(const std::string& name); public: static int upstream_add_server(const std::string& name, const std::string& address); static int upstream_add_server(const std::string& name, const std::string& address, const struct AddressParams *address_params); static int upstream_remove_server(const std::string& name, const std::string& address); ... } ~~~ ### 例1 在多个目标中随机访问 配置一个本地反向代理,将本地发出的my_proxy.name所有请求均匀的打到6个目标server上 ~~~cpp UpstreamManager::upstream_create_weighted_random( "my_proxy.name", true);//如果遇到熔断机器,再次尝试直至找到可用或全部熔断 UpstreamManager::upstream_add_server("my_proxy.name", "192.168.2.100:8081"); UpstreamManager::upstream_add_server("my_proxy.name", "192.168.2.100:8082"); UpstreamManager::upstream_add_server("my_proxy.name", "192.168.10.10"); UpstreamManager::upstream_add_server("my_proxy.name", "test.sogou.com:8080"); UpstreamManager::upstream_add_server("my_proxy.name", "abc.sogou.com"); UpstreamManager::upstream_add_server("my_proxy.name", "abc.sogou.com"); UpstreamManager::upstream_add_server("my_proxy.name", "/dev/unix_domain_scoket_sample"); auto *http_task = WFTaskFactory::create_http_task("http://my_proxy.name/somepath?a=10", 0, 0, nullptr); http_task->start(); ~~~ 基本原理 1. 随机选择一个目标 2. 如果try_another配置为true,那么将在所有存活的目标中随机选择一个 3. 
仅在main中选择,选中目标所在group的主备和无group的备都视为有效的可选对象 ### 例2 在多个目标中按照权重大小随机访问 配置一个本地反向代理,将本地发出的weighted.random所有请求按照5/20/1的权重分配打到3个目标server上 ~~~cpp UpstreamManager::upstream_create_weighted_random( "weighted.random", false);//如果遇到熔断机器,不再尝试,这种情况下此次请求必定失败 AddressParams address_params = ADDRESS_PARAMS_DEFAULT; address_params.weight = 5;//权重为5 UpstreamManager::upstream_add_server("weighted.random", "192.168.2.100:8081", &address_params);//权重5 address_params.weight = 20;//权重为20 UpstreamManager::upstream_add_server("weighted.random", "192.168.2.100:8082", &address_params);//权重20 UpstreamManager::upstream_add_server("weighted.random", "abc.sogou.com");//权重1 auto *http_task = WFTaskFactory::create_http_task("http://weighted.random:9090", 0, 0, nullptr); http_task->start(); ~~~ 基本原理 1. 按照权重分配,随机选择一个目标,权重越大概率越大 2. 如果try_another配置为true,那么将在所有存活的目标中按照权重分配随机选择一个 3. 仅在main中选择,选中目标所在group的主备和无group的备都视为有效的可选对象 ### 例3 在多个目标中按照框架默认的一致性哈希访问 ~~~cpp UpstreamManager::upstream_create_consistent_hash( "abc.local", nullptr);//nullptr代表使用框架默认的一致性哈希函数 UpstreamManager::upstream_add_server("abc.local", "192.168.2.100:8081"); UpstreamManager::upstream_add_server("abc.local", "192.168.2.100:8082"); UpstreamManager::upstream_add_server("abc.local", "192.168.10.10"); UpstreamManager::upstream_add_server("abc.local", "test.sogou.com:8080"); UpstreamManager::upstream_add_server("abc.local", "abc.sogou.com"); auto *http_task = WFTaskFactory::create_http_task("http://abc.local/service/method", 0, 0, nullptr); http_task->start(); ~~~ 基本原理 1. 每1个main视为16个虚拟节点 2. 框架会使用std::hash对所有节点的address+虚拟index+此address加到此upstream的次数进行运算,作为一致性哈希的node值 3. 框架会使用std::hash对path+query+fragment进行运算,作为一致性哈希data值 4. 每次都选择存活node最近的值作为目标 5. 对于每一个main、只要有存活group内main/有存活group内backup/有存活no group backup,即视为存活 6. 
如果upstream_add_server()时加上AddressParams,并配上权重weight,则每1个main视为16 * weight个虚拟节点,适用于带权一致性哈希或者希望一致性哈希标准差更小的场景 ### 例4 自定义一致性哈希函数 ~~~cpp UpstreamManager::upstream_create_consistent_hash( "abc.local", [](const char *path, const char *query, const char *fragment) -> unsigned int { unsigned int hash = 0; while (*path) hash = (hash * 131) + (*path++); while (*query) hash = (hash * 131) + (*query++); while (*fragment) hash = (hash * 131) + (*fragment++); return hash; }); UpstreamManager::upstream_add_server("abc.local", "192.168.2.100:8081"); UpstreamManager::upstream_add_server("abc.local", "192.168.2.100:8082"); UpstreamManager::upstream_add_server("abc.local", "192.168.10.10"); UpstreamManager::upstream_add_server("abc.local", "test.sogou.com:8080"); UpstreamManager::upstream_add_server("abc.local", "abc.sogou.com"); auto *http_task = WFTaskFactory::create_http_task("http://abc.local/sompath?a=1#flag100", 0, 0, nullptr); http_task->start(); ~~~ 基本原理 1. 框架会使用用户自定义的一致性哈希函数作为data值 2. 其余与上例原理一致 ### 例5 自定义选取策略 ~~~cpp UpstreamManager::upstream_create_manual( "xyz.cdn", [](const char *path, const char *query, const char *fragment) -> unsigned int { return atoi(fragment); }, true,//如果选择到已经熔断的目标,将进行二次选取 nullptr);//nullptr代表二次选取时使用框架默认的一致性哈希函数 UpstreamManager::upstream_add_server("xyz.cdn", "192.168.2.100:8081"); UpstreamManager::upstream_add_server("xyz.cdn", "192.168.2.100:8082"); UpstreamManager::upstream_add_server("xyz.cdn", "192.168.10.10"); UpstreamManager::upstream_add_server("xyz.cdn", "test.sogou.com:8080"); UpstreamManager::upstream_add_server("xyz.cdn", "abc.sogou.com"); auto *http_task = WFTaskFactory::create_http_task("http://xyz.cdn/sompath?key=somename#3", 0, 0, nullptr); http_task->start(); ~~~ 基本原理 1. 框架首先依据用户提供的普通选取函数、按照取模,在main列表中确定选取 2. 对于每一个main、只要有存活group内main/有存活group内backup/有存活no group backup,即视为存活 3. 如果选中目标不再存活且try_another设为true,将再使用一致性哈希函数进行二次选取 4. 
如果触发二次选取,一致性哈希将保证一定会选择一个存活目标、除非全部机器都被熔断掉 ### 例6 简单的主备模式 ~~~cpp UpstreamManager::upstream_create_weighted_random( "simple.name", true);//一主一备这项设什么没区别 AddressParams address_params = ADDRESS_PARAMS_DEFAULT; address_params.server_type = 0; UpstreamManager::upstream_add_server("simple.name", "main01.test.ted.bj.sogou", &address_params);//主 address_params.server_type = 1; UpstreamManager::upstream_add_server("simple.name", "backup01.test.ted.gd.sogou", &address_params);//备 auto *http_task = WFTaskFactory::create_http_task("http://simple.name/request", 0, 0, nullptr); auto *redis_task = WFTaskFactory::create_redis_task("redis://simple.name/2", 0, nullptr); redis_task->get_req()->set_query("MGET", {"key1", "key2", "key3", "key4"}); (*http_task * redis_task).start(); ~~~ 基本原理 1. 主备模式与前面所展示的任何模式都不冲突,可以同时生效 2. 主备数量各自独立,没有限制。主和主之间平等,备与备之间平等,主备之间不平等。 3. 只要有主活着,请求一直会使用某一个主 4. 如果主都被熔断,备将作为替代目标接管请求直至有主恢复正常 5. 在每一个策略中,存活的备都可以作为主的存活依据 ### 例7 主备+一致性哈希+分组 ~~~cpp UpstreamManager::upstream_create_consistent_hash( "abc.local", nullptr);//nullptr代表使用框架默认的一致性哈希函数 AddressParams address_params = ADDRESS_PARAMS_DEFAULT; address_params.server_type = 0; address_params.group_id = 1001; UpstreamManager::upstream_add_server("abc.local", "192.168.2.100:8081", &address_params);//main in group 1001 address_params.server_type = 1; address_params.group_id = 1001; UpstreamManager::upstream_add_server("abc.local", "192.168.2.100:8082", &address_params);//backup for group 1001 address_params.server_type = 0; address_params.group_id = 1002; UpstreamManager::upstream_add_server("abc.local", "main01.test.ted.bj.sogou", &address_params);//main in group 1002 address_params.server_type = 1; address_params.group_id = 1002; UpstreamManager::upstream_add_server("abc.local", "backup01.test.ted.gd.sogou", &address_params);//backup for group 1002 address_params.server_type = 1; address_params.group_id = -1; UpstreamManager::upstream_add_server("abc.local", "test.sogou.com:8080", &address_params);//backup for no group 
mean backup for all group and no group UpstreamManager::upstream_add_server("abc.local", "abc.sogou.com");//main, no group auto *http_task = WFTaskFactory::create_http_task("http://abc.local/service/method", 0, 0, nullptr); http_task->start(); ~~~ 基本原理 1. 组号-1代表无组,这种目标不属于任何组 2. 无组的main之间是平等的,甚至可以视为同一个组。但与有组的main之间是隔离的 3. 无组的backup可以为全局任何组目标/任何无组目标作为备 4. 组号可以区分哪些主备是在一起工作的 5. 不同组之间的备是相互隔离的,只为本组的main服务 6. 添加目标的默认组号-1,type为0,表示主节点。 ### 例8 NVSWRR平滑按权重选取策略 ~~~cpp UpstreamManager::upstream_create_vnswrr("nvswrr.random"); AddressParams address_params = ADDRESS_PARAMS_DEFAULT; address_params.weight = 3;//权重为3 UpstreamManager::upstream_add_server("nvswrr.random", "192.168.2.100:8081", &address_params);//权重3 address_params.weight = 2;//权重为2 UpstreamManager::upstream_add_server("nvswrr.random", "192.168.2.100:8082", &address_params);//权重2 UpstreamManager::upstream_add_server("nvswrr.random", "abc.sogou.com");//权重1 auto *http_task = WFTaskFactory::create_http_task("http://nvswrr.random:9090", 0, 0, nullptr); http_task->start(); ~~~ 基本原理 1. 虚拟节点初始化顺序按照[SWRR算法](https://github.com/nginx/nginx/commit/52327e0627f49dbda1e8db695e63a4b0af4448b1)选取 2. 虚拟节点运行时分批初始化,避免密集型计算集中,每批次虚拟节点使用完后再进行下一批次虚拟节点列表初始化 3. 兼具[SWRR算法](https://github.com/nginx/nginx/commit/52327e0627f49dbda1e8db695e63a4b0af4448b1)的平滑、分散特点,又能具备O(1)的时间复杂度 4. 算法具体细节参见[tengine](https://github.com/alibaba/tengine/pull/1306) # Upstream选择策略 当发起请求的url的URIHost填UpstreamName时,视做对与名字对应的Upstream发起请求,接下来将会在Upstream记录的这组Address中进行选择: 1. 权重随机策略:按照权重随机选择 2. 一致性哈希策略:框架使用标准的一致性哈希算法,用户可以自定义对请求uri的一致性哈希函数consistent_hash 3. 手动策略:根据用户提供的对请求uri的select函数进行确定的选择,如果选中了已经熔断的目标: a. 如果try_another为false,这次请求将返回失败 b. 如果try_another为true,框架使用标准的一致性哈希算法重新选取,用户可以自定义对请求uri的一致性哈希函数consistent_hash 4. 
主备策略:按照先主后备的优先级,只要主可以用就选择主。此策略可以与[1]、[2]、[3]中的任何一个同时生效,相互影响。 round-robin/weighted-round-robin:视为与[1]等价,暂不提供 框架建议普通用户使用策略[2],可以保证集群具有良好的容错性和可扩展性 对于复杂需求场景,高级用户可以使用策略[3],订制复杂的选择逻辑 # Address属性 ~~~cpp struct EndpointParams { size_t max_connections; int connect_timeout; int response_timeout; int ssl_connect_timeout; bool use_tls_sni; }; static constexpr struct EndpointParams ENDPOINT_PARAMS_DEFAULT = { .max_connections = 200, .connect_timeout = 10 * 1000, .response_timeout = 10 * 1000, .ssl_connect_timeout = 10 * 1000, .use_tls_sni = false, }; struct AddressParams { struct EndpointParams endpoint_params; unsigned int dns_ttl_default; unsigned int dns_ttl_min; unsigned int max_fails; unsigned short weight; int server_type; int group_id; }; static constexpr struct AddressParams ADDRESS_PARAMS_DEFAULT = { .endpoint_params = ENDPOINT_PARAMS_DEFAULT, .dns_ttl_default = 12 * 3600, .dns_ttl_min = 180, .max_fails = 200, .weight = 1, //only for main of UPSTREAM_WEIGHTED_RANDOM .server_type = 0, .group_id = -1, }; ~~~ 每个Addreess都可以配置自己的自定义参数: * EndpointParams的max_connections, connect_timeout, response_timeout, ssl_connect_timeout:连接相关的参数 * dns_ttl_default:dns cache中默认的ttl,单位秒,默认12小时,dns cache是针对当前进程的,即进程退出就会消失,配置也仅对当前进程有效 * dns_ttl_min:dns最短生效时间,单位秒,默认3分钟,用于在通信失败重试时是否进行重新dns的决策 * max_fails:触发熔断的【连续】失败次数(注:每次通信成功,计数会清零) * weight:权重,默认1,仅对main有效,用于Upstream随机策略选取和一致性哈希选取,权重大越容易被选中 * server_type:主备配置,默认主。无论什么时刻,同组的主优先级永远高于其他的备 * group_id:分组依据,默认-1。-1代表无分组(游离),游离的备可视为任何主的备,有组的备优先级永远高于游离的备。 # 关于熔断 ## MTTR 平均修复时间(Mean time to repair,MTTR),是描述产品由故障状态转为工作状态时修理时间的平均值。 ## 服务雪崩效应 服务雪崩效应是一种因“服务提供者的故障”(原因),导致“服务调用者故障”(结果),并将不可用逐渐/逐级放大的现象 若不加以有效控制,效应不会收敛,而且会以几何级放大,犹如雪崩,雪崩效应因此得名 日常表现通常为:起初只是一个很小的服务or模块异常/超时,引起下游其他依赖的服务随之异常/超时,产生连锁反应,最终导致绝大多数甚至全部的服务陷入瘫痪 随着故障的修复,效应随之消失,所以效应持续时间通常等于MTTR ## 熔断机制 当某一个目标的错误or异常触达到预先设定的阈值条件时,暂时认为这个目标不可用,剔除目标,即熔断开启进入熔断期 在熔断持续时间达到MTTR时长后,会进入半熔断状态,(尝试)恢复目标 如果恢复的时候发现其他所有目标都被熔断,会同一时间把所有目标恢复 熔断机制策略可以有效阻止雪崩效应 ## Upstream熔断保护机制 MTTR=30秒,暂时不可配置,后续会考虑开放给用户自行配置 
当某一个Address连续失败次数达到设定上限(默认200次),这个Address会被熔断MTTR=30秒 Address在熔断期间,一旦被策略选中,Upstream会根据具体配置决定是否尝试其他Address、如何尝试 请注意满足下面1-4的某个情景,通信任务将得到一个WFT_ERR_UPSTREAM_UNAVAILABLE = 1004的错误: 1. 权重随机策略,全部目标都处于熔断期 2. 一致性哈希策略,全部目标都处于熔断期 3. 手动策略 && try_another==true,全部目标都处于熔断期 4. 手动策略 && try_another==false,且同时满足下面三个条件: 1). select函数选中的main处于熔断期,且游离的备都处于熔断期 2). 这个main是游离的主,或者这个main所在的group其他目标都处于熔断期 3). 所有游离的备都处于熔断期 # Upstream端口优先级 1. 优先选择显式配置在Upstream Address上的端口号 2. 若没有,再选择显式配置在请求url中的端口号 3. 若都没有,使用协议默认端口号 ~~~text 配置 UpstreamManager::upstream_add_server("my_proxy.name", "192.168.2.100:8081"); 请求 http://my_proxy.name:456/test.html => http://192.168.2.100:8081/test.html 请求 http://my_proxy.name/test.html => http://192.168.2.100:8081/test.html ~~~ ~~~text 配置 UpstreamManager::upstream_add_server("my_proxy.name", "192.168.10.10"); 请求 http://my_proxy.name:456/test.html => http://192.168.10.10:456/test.html 请求 http://my_proxy.name/test.html => http://192.168.10.10:80/test.html ~~~ workflow-0.11.8/docs/benchmark.md000066400000000000000000000154151476003635400166750ustar00rootroot00000000000000# 性能测试 Sogou C++ Workflow是一款性能优异的网络框架,本文介绍我们进行的性能测试, 包括方案、代码、结果,以及与其他同类产品的对比。 更多场景下的实验正在进行中,本文将持续更新。 ## HTTP Server HTTP Client/Server是Sogou C++ Workflow常见的应用场景, 我们首先对Server端进行实验。 ### 环境 我们部署了两台相同机器作为Server和Client,软硬件配置如下: | 软硬件 | 配置 | |:---:|:---| | CPU | 40 Cores, x86_64, Intel(R) Xeon(R) CPU E5-2630 v4 @ 2.20GHz | | Memory | 192GB | | NIC | 25000Mbps | | OS | CentOS 7.8.2003 | | Kernel | Linux version 3.10.0-1127.el7.x86_64 | | GCC | 4.8.5 | 两者间`ping`测得的RTT为0.1ms左右。 ### 对照组 我们选择nginx和brpc作为对照组。 选择前者是因为它在生产中部署十分广泛,性能不俗; 对于后者,我们在本次实验中只关注HTTP Server方面的能力, 其他的特性已有[单独的实验][Sogou RPC Benchmark]进行更为详尽的测试。 事实上,我们也对此二者之外的其他某些框架同时进行了实验, 但结果其性能表现相差较远,因此未在本文中体现。 后续我们将选取更多合适的框架加入对比测试中。 ### Client工具 本次实验我们使用的压测工具为[wrk][wrk]和[wrk2][wrk2]。 前者适合测试特定并发下的QPS极限和延时, 后者适合在特定QPS下测试延时分布。 我们也尝试过使用其他测试工具,例如[ab][ab]等,但无法打出足够的压力。 有鉴于此,我们也在着手开发基于Sogou C++ Workflow的benchmark工具。 ### 变量和指标 一般而言,对网络框架的性能测试,切入的角度可谓纷繁多样。
通过控制不同的变量、观测不同的指标,可以探究程序在不同场景下的适应能力。 本次实验,我们选择其中最普遍常见的变量和指标: 通过控制Client并发度和承载数据的大小,来测试QPS和延时的变化情况。 另外,我们还测试了在掺杂慢请求的正常请求的延时分布。 下面依次介绍两个测试场景。 ### 不同并发度和数据长度下的QPS和延时 #### 代码和配置 我们搭建了一个极其简约的HTTP服务器, 忽略掉所有的业务逻辑, 将测试点聚焦在纯粹的网络框架性能上。 代码片段如下, 完整代码移步[这里][benchmark-01 Code]。 ```cpp // ... auto * resp = task->get_resp(); resp->add_header_pair("Date", timestamp); resp->add_header_pair("Content-Type", "text/plain; charset=UTF-8"); resp->append_output_body_nocopy(content.data(), content.size()); // ... ``` 可以从上述代码中看到, 对于到来的任何HTTP请求, 我们都会返回一段固定的内容作为Body, 并设置必要的Header, 包括代码中指明的`content-type`、`date`, 以及自动填充的`connection`和`content-length`。 HTTP Body的固定内容是在Server启动时随机生成的ASCII字符串, 其长度可以通过启动参数配置。 同时可以配置的还有使用的poller线程数和监听的端口号。 前者我们在本次测试中固定为16, 因此Sogou C++ Workflow将使用16个poller线程和20个handler线程(默认配置)。 对于nginx和brpc, 我们也构建了相同的返回内容, 并为nginx配置了40个进程、 brpc配置了40个线程。 #### 变量 我们控制并发度在`[1, 2K]`之间翻倍增长, 数据长度在`[16B, 64KB]`之间翻倍增长, 两者正交。 #### 指标 鉴于并发度和数据长度组合之后数量较多, 我们选择其中部分数据绘制为曲线。 ##### 固定数据长度下QPS与并发度关系 ![Concurrency and QPS][Con-QPS] 上图可以看出,当数据长度保持不变, QPS随着并发度提高而增大,后趋于平稳。 此过程中Sogou C++ Workflow一直有明显优势, 高于brpc和nginx。 特别是数据长度为64和512的两条曲线, 并发度足够的时候,可以保持500K的QPS。 注意上图中nginx-64与nginx-512的曲线重叠度很高, 不易辨识。 ##### 固定并发度下QPS与数据长度关系 ![Body Length and QPS][Len-QPS] 上图可以看出,当并发度保持不变, 随着数据长度的增长, QPS保持平稳至4K时下降。 此过程中,Sogou C++ Workflow也一直保持优势。 ##### 固定数据长度下延时与并发度关系 ![Concurrency and Latency][Con-Lat] 上图可以看出,保持数据长度不变, 延时随并发度提高而有所上升。 此过程中,Sogou C++ Workflow略好于brpc, 大好于nginx。 ##### 固定并发度下延时与数据长度关系 ![Body Length and Latency][Len-Lat] 上图可以看出,并发度保持不变时, 增大数据长度,造成延时上升。 此过程中,Sogou C++ Workflow好于nginx, 好于brpc。 ### 掺杂慢请求的延时分布 #### 代码 我们在上一个测试的基础上,简单添加了一个慢请求的逻辑, 模拟业务场景中可能出现的特殊情况。 代码片段如下, 完整代码请移步[这里][benchmark-02 Code]。 ```cpp // ... if (std::strcmp(uri, "/long_req/") == 0) { auto timer_task = WFTaskFactory::create_timer_task(microseconds, nullptr); series_of(task)->push_back(timer_task); } // ... 
``` 我们在Server的process里进行判断, 如果访问的是特定的路径, 则添加一个`WFTimerTask`到Series的末尾, 能够模拟一个异步耗时处理过程。 类似地,对brpc使用`bthread_usleep()`函数进行异步睡眠。 #### 配置 在本次实验中,我们固定并发度为1024,数据长度为1024字节, 分别以QPS为20K、100K和200K进行正常请求测试, 测绘延时; 与此同时,有另一路压力,进行慢请求, QPS是上述QPS的1%, 数据不计入统计。 慢请求的时长固定为5ms。 #### 延时CDF图 ![Latency CDF][Lat CDF] 从上图可以看出,当QPS为20K时, Sogou C++ Workflow略次于brpc; 当QPS为100K时,两者几乎相当; 当QPS为200K时,Sogou C++ Workflow略好于brpc。 总之,可以认为两者在这方面旗鼓相当。 [Sogou RPC Benchmark]: https://github.com/holmes1412/sogou-rpc-benchmark [wrk]: https://github.com/wg/wrk [wrk2]: https://github.com/giltene/wrk2 [ab]: https://httpd.apache.org/docs/2.4/programs/ab.html [benchmark-01 Code]: ../benchmark/benchmark-01-http_server.cc [benchmark-02 Code]: ../benchmark/benchmark-02-http_server_long_req.cc [Con-QPS]: https://raw.githubusercontent.com/wiki/sogou/workflow/img/benchmark-01.png [Len-QPS]: https://raw.githubusercontent.com/wiki/sogou/workflow/img/benchmark-02.png [Con-Lat]: https://raw.githubusercontent.com/wiki/sogou/workflow/img/benchmark-03.png [Len-Lat]: https://raw.githubusercontent.com/wiki/sogou/workflow/img/benchmark-04.png [Lat CDF]: https://raw.githubusercontent.com/wiki/sogou/workflow/img/benchmark-05.png workflow-0.11.8/docs/bugs.md000066400000000000000000000037771476003635400157130ustar00rootroot00000000000000# 已知BUG列表 ### OpenSSL 1.1.1及以下,出现网络任务状态为WFT_STATE_SYS_ERROR,错误为0。 这是OpenSSL 1.1.1及以下的bug,在SSL_get_error()为SSL_ERROR_SYSCALL时,errno被置为0。由于框架会把SSL_ERROR_SYSCALL转为系统错误,这会导致我们得到一个错误码0的系统错误: ~~~cpp void callback(WFHttpTask *task) { int state = task->get_state(); int error = task->get_error(); printf("%d, %d\n", state, error); // 此处得到1,0,其中1是WFT_STATE_SYS_ERROR。 } ~~~ 显然只有在SSL通信下可能出现在这个问题。这个bug在OpenSSL 3.0里被修复,建议升级到OpenSSL 3.0或以上。 相关issue:https://github.com/openssl/openssl/issues/12416 ### 访问HTTPS网页,当打开TLS SNI并使用upstream时出现SSL error。 当我们创建Http任务,http header里的Host域填写的是原始URL里的host部分。例如: ~~~cpp void f() { auto *task = WFTaskFactory::create_http_task("https://sogou/index.html", 0, 0, nullptr); } ~~~ 这时候http 
request里的Host必然填写的是"sogou"。此时如果sogou是一个upstream名,指向域名www.sogou.com。并且我们开启了TLS SNI,那么SNI server name信息就是www.sogou.com,与http header里的Host是不一致的,会导致SSL错误。 要解决这个问题,用户可以通过设置prepare函数,在发送请求前修改Host,让它与最终URL里的一致: ~~~cpp void f() { auto *task = WFTaskFactory::create_http_task("https://sogou/index.html", 0, 0, nullptr); task->set_prepare([](WFHttpTask *task){ auto *t = static_cast<WFComplexClientTask<protocol::HttpRequest, protocol::HttpResponse> *>(task); task->get_req()->set_header_pair("Host", t->get_current_uri()->host); // 这里得到实际uri里的host。 }); } ~~~ 只有打开了TLS SNI功能并使用upstream会出这个不一致问题。当然,很多时候我们配置upstream来访问http网站,也需要做这个修改,否则对方可能不会接受你的Host信息。 workflow-0.11.8/docs/en/000077500000000000000000000000001476003635400150155ustar00rootroot00000000000000workflow-0.11.8/docs/en/CONTRIBUTING.md000066400000000000000000000100271476003635400172460ustar00rootroot00000000000000# Contribution Guide Sogou C++ Workflow is community-driven and welcomes any contributor. This document outlines some conventions about development steps, commit message formatting and contact points to make it easier to get your contribution accepted. - [Code of Conduct](#code-of-conduct) - [Getting started](#getting-started) - [First Contribution](#first-contribution) - [Find a good first topic](#find-a-good-first-topic) - [Work on an existed issue](#work-on-an-existed-issue) - [File a new issue](#file-a-new-issue) - [Contributor workflow](#contributor-workflow) - [Creating Pull Requests](#creating-pull-requests) - [Code Review](#code-review) - [Testing and building](#testing-and-building) # Code of Conduct Please make sure to read and observe our [Code of Conduct](/CODE_OF_CONDUCT.md). # Getting started - Fork the repository on GitHub. - Make your changes on your fork repository. - Submit a PR. # First Contribution We will help you to contribute in different areas like filing issues, developing features, fixing critical bugs and getting your work reviewed and merged.
If you have questions about the development process, feel free to [file an issue](https://github.com/sogou/workflow/issues/new/choose). We are always in need of help, be it fixing documentation, reporting bugs or writing some code. Look at places where you feel best coding practices aren't followed, code refactoring is needed or tests are missing. Here is how you get started. ### Find a good first topic You can start by finding an existing issue with the [help-wanted](https://github.com/sogou/workflow/issues?q=is%3Aissue+is%3Aopen+label%3A%22help+wanted%22) and [good first issue](https://github.com/sogou/workflow/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22) labels in this repository. These issues are well suited for new contributors because they are beginner-friendly. We can help new contributors who wish to work on such issues. Another good way to contribute is to find a documentation improvement, such as a missing/broken link. #### Work on an existed issue When you are willing to take on an issue, just reply on the issue. The maintainer will assign it to you. ### File a new issue While we encourage everyone to contribute code, it is also appreciated when someone reports an issue. Please follow the prompted submission guidelines while opening an issue. # Contributor workflow To contribute to the code base, please follow the workflow as defined in this section. 1. Create a topic branch from where you want to base your work. This is usually master. 2. Make commits of logical units and add test cases if the change fixes a bug or adds new functionality. 3. Run tests and make sure all the tests pass. 4. Make sure your commit messages are in the proper format. 5. Push your changes to a topic branch in your fork of the repository. 6. Submit a pull request. This is a rough outline of what a contributor's workflow looks like. For more details, you are encouraged to communicate with the reviewers before sending a pull request. Thanks for your contributions!
## Creating Pull Requests Our project generally follows the standard [github pull request](https://help.github.com/articles/about-pull-requests/) process. To submit a proposed change, please develop the code/fix and add new test cases. After that, run these local verifications before submitting a pull request to predict the pass or fail of continuous integration. ## Code Review To make it easier for your Pull Request to receive reviews, break large changes into a logical series of smaller patches which individually make easily understandable changes, and in aggregate solve a broader issue. If this is an independent modification, then it is recommended that you provide a tutorial and corresponding documents, and communicate with us. ## Testing and building Make sure the [travis-ci](https://travis-ci.com/github/sogou/workflow/pull_requests) build passes. Once your PR has been merged, you become a contributor. Thank you for your contribution! workflow-0.11.8/docs/en/about-config.md000066400000000000000000000105421476003635400177160ustar00rootroot00000000000000# About global configuration Global configuration is used to configure default global parameters to meet the actual business requirements and improve performance. The change of the global configuration must be made before you call any interfaces in the framework. Otherwise the change may not take effect. In addition, some global configuration items can be overridden in the upstream configuration. Please see upstream documents for reference. # Changing default configuration [WFGlobal.h](/src/manager/WFGlobal.h) defines the struct and the default values of the global configuration.
~~~cpp struct WFGlobalSettings { struct EndpointParams endpoint_params; struct EndpointParams dns_server_params; unsigned int dns_ttl_default; ///< in seconds, DNS TTL when network request success unsigned int dns_ttl_min; ///< in seconds, DNS TTL when network request fail int dns_threads; int poller_threads; int handler_threads; int compute_threads; ///< auto-set by system CPU number if value<=0 int fio_max_events; const char *resolv_conf_path; const char *hosts_path; }; static constexpr struct WFGlobalSettings GLOBAL_SETTINGS_DEFAULT = { .endpoint_params = ENDPOINT_PARAMS_DEFAULT, .dns_server_params = ENDPOINT_PARAMS_DEFAULT, .dns_ttl_default = 12 * 3600, .dns_ttl_min = 180, .dns_threads = 4, .poller_threads = 4, .handler_threads = 20, .compute_threads = -1, .fio_max_events = 4096, .resolv_conf_path = "/etc/resolv.conf", .hosts_path = "/etc/hosts", }; ~~~ [EndpointParams.h](/src/manager/EndpointParams.h) defines the structure of EndpointParams and the default values. ~~~cpp struct EndpointParams { size_t max_connections; int connect_timeout; int response_timeout; int ssl_connect_timeout; bool use_tls_sni; }; static constexpr struct EndpointParams ENDPOINT_PARAMS_DEFAULT = { .max_connections = 200, .connect_timeout = 10 * 1000, .response_timeout = 10 * 1000, .ssl_connect_timeout = 10 * 1000, .use_tls_sni = false, }; ~~~ If you want to change the default connecting timeout to 5 seconds, the default TTL for DNS to 1 hour and increase the number of poller threads for message deserialization to 10, you can follow the example below: ~~~cpp #include "workflow/WFGlobal.h" int main() { struct WFGlobalSettings settings = GLOBAL_SETTINGS_DEFAULT; settings.endpoint_params.connect_timeout = 5 * 1000; settings.dns_ttl_default = 3600; settings.poller_threads = 10; WORKFLOW_library_init(&settings); ... } ~~~ Most of the parameters are self-explanatory. Note: the ttl and related parameters in DNS configuration are in **seconds**.
The timeout for endpoint is in **milliseconds**, and -1 indicates an infinite timeout. dns\_threads indicates the total number of threads accessing DNS in parallel, but by default, we use asynchronous DNS resolving and don't create any dns threads (except on the Windows platform). dns\_server\_params indicates the parameters used when accessing the DNS server, including the maximum concurrent connections, and the DNS server's connecting and response timeout. compute\_threads indicates the number of threads used for computation. The default value is -1, meaning the number of threads is the same as the number of CPU cores in the current node. fio\_max\_events indicates the maximum number of concurrent asynchronous file IO events. resolv\_conf\_path indicates the path of dns resolving configuration file. The default value is "/etc/resolv.conf" on unix platforms and NULL on windows. On the windows platform, we still use multi-threaded dns resolving by default. hosts_path indicates the path of the **hosts** file. The default value is "/etc/hosts" on unix platforms. If resolv_conf_path is NULL, this configuration will be ignored. poller\_threads and handler\_threads are the two parameters for tuning network performance: * poller\_threads is mainly used for epoll (kqueue) and message deserialization. * handler\_threads is the number of threads for the callback and the process of a network task. All resources required by the framework are applied for when they are used for the first time. For example, if a user task does not involve DNS resolution, the asynchronous DNS resolver or DNS threads will not be created. workflow-0.11.8/docs/en/about-connection-context.md000066400000000000000000000153341476003635400222760ustar00rootroot00000000000000# About connection context Connection context is an advanced programming topic in this framework. From the previous examples, we can see that we cannot assign one specific connection for a client task or a server task.
However, in some business scenarios, especially for the server, we may need to maintain the connection status. In other words, we need to bind a context to a connection. In the framework, we provide a connection context for users. # Application scenarios for connection context HTTP is a completely stateless protocol, and HTTP session is realized with cookies. HTTP, Kafka and other stateless protocols are most friendly with our framework. The connection used by Redis and MySQL is obviously stateful. Redis specifies the database ID on the current connection with the SELECT command. MySQL uses a completely stateful connection. When you use Redis or non-transactional MySQL client tasks in the framework, the URL already contains all the information related to the connection, including: * username and password * database name or database ID * the character set for MySQL The framework will automatically log in or select a reusable connection based on the above information, and you do not need to care about the connection context. Due to this limitation, in the framework, you cannot use the SELECT command of Redis and the USE command of MySQL. If you want to switch databases, you should use a new URL to create the task. Transactional MySQL tasks can use fixed connections. Please see the MySQL documentation for relevant details. However, if you implement a server based on Redis protocol, you need to know the current connection status. By using the deleter function of connection context, users can also get notified when the connection is closed by the peer. # How to use connection context Note: generally, only the server tasks need to use the connection context, and the connection context is used only inside the process function, which is also the safest and simplest. You can also use or modify the connection context in the callback, but you should consider the concurrency problem. We’ll discuss the related issues in detail.
You can obtain the connection object in any network task through interfaces, and then obtain or modify the connection context. [WFTask.h](../src/factory/WFTask.h) contains a sample call: ~~~cpp template<class REQ, class RESP> class WFNetworkTask : public CommRequest { public: virtual WFConnection *get_connection() const = 0; ... }; ~~~ [WFConnection.h](../src/factory/WFConnection.h) contains the interfaces for performing operations on the connection objects: ~~~cpp class WFConnection : public CommConnection { public: void *get_context() const; void set_context(void *context, std::function<void (void *)> deleter); void set_context(void *context); void *test_set_context(void *test_context, void *new_context, std::function<void (void *)> deleter); void *test_set_context(void *test_context, void *new_context); }; ~~~ **get\_connection()** can only be called in a process or a callback. If you call it in the callback, please check whether the return value is NULL. If you get the WFConnection object successfully, you can perform operations on the connection context. A connection context is a void \* pointer. When the connection is closed, the deleter is automatically called. When using the setting context functions without ``deleter`` argument, the original deleter will be kept unchanged. # Timing and concurrency for accessing connection context When a client task is created, the connection object is not determined. Thus, for all client tasks, you can only use the connection context in the callback. For server tasks, you may use connection context in the process or the callback. When you use connection context in a callback, you need to consider concurrency, because the same connection may be reused by multiple tasks and reach the callbacks at the same time. Therefore, we recommend that the connection context should be accessed or modified only in the process function, because the connection will not be reused or released in the process, which is the simplest and safest.
Note: the process in the above paragraphs means only the places inside the process function. In the places after the process function and before the callback, get\_connection() always returns NULL. **test\_set\_context()** in the WFConnection is used to solve the concurrency issues for using connection context in the callback, but it is not recommended. In a word, if you are not very familiar with the system implementation, please use the connection context only in the process function of the server tasks. # Example: how to reduce the request header fields in HTTP/1.1 HTTP protocol is a stateless connection protocol, and a complete header must be sent for every request on the same connection. If the cookie in the request is very large, it will obviously increase the data transmission overload. You can use the server-side connection context to solve this issue. You can specify that the cookie in the HTTP request is valid for all subsequent requests on the same connection, and omit the cookie in the subsequent request headers. Please see the following codes on the server side: ~~~cpp void process(WFHttpTask *server_task) { protocol::HttpRequest *req = server_task->get_req(); protocol::HttpHeaderCursor cursor(req); WFConnection *conn = server_task->get_connection(); void *context = conn->get_context(); std::string cookie; if (cursor.find("Cookie", cookie)) { if (context) delete (std::string *)context; context = new std::string(cookie); conn->set_context(context, [](void *p) { delete (std::string *)p; }); } else if (context) cookie = *(std::string *)context; ... } ~~~ In this way, if you arrange with the client that the cookie is transmitted only at the first request of the connection, the traffic can be reduced. The implementation in the client side needs to use a new **prepare** function. 
Please see the codes below: ~~~cpp using namespace protocol; void prepare_func(WFHttpTask *task) { if (task->get_task_seq() == 0) task->get_req()->add_header_pair("Cookie", my_cookie); } int some_function() { WFHttpTask *task = WFTaskFactory::create_http_task(...); task->set_prepare(prepare_func); ... } ~~~ In the example, when the HTTP task is the first request on the connection, the cookie is set. If it is not the first request, according to our arrangement, we do not set the cookie. In addition, you may use the connection context safely in the **prepare** function. **prepare** will not be concurrent on the same connection. workflow-0.11.8/docs/en/about-counter.md000066400000000000000000000214031476003635400201260ustar00rootroot00000000000000# About counter Counters are very important basic tasks in our framework. A counter is essentially a semaphore that does not occupy thread. Counters are mainly used for workflow control. It includes anonymous counters and named counters, and can realize very complex business logic. # Creating a counter As a counter is also a task, it is created through WFTaskFactory. You can create a counter with one of the following two methods: ~~~cpp using counter_callback_t = std::function; class WFTaskFactory { ... static WFCounterTask *create_counter_task(unsigned int target_value, counter_callback_t callback); static WFCounterTask *create_counter_task(const std::string& counter_name, unsigned int target_value, counter_callback_t callback); ... }; ~~~ Each counter contains a target\_value. When the count in the counter reaches the target\_value, its callback is called. The above two interfaces generate a anonymous counter and a named counter respectively. The anonymous counter directly increases the count through the count method in the WFCounterTask: ~~~cpp class WFCounterTask { public: virtual void count() { ... } ... 
} ~~~ If a counter\_name is passed when you create a counter, a named counter is generated, and the count can be increased with the count\_by\_name function. # Creating parallel tasks with anonymous counters In the example of [parallel wget](/docs/en/tutorial-06-parallel_wget.md), we created a ParallelWork to achieve the parallel execution of several series. With the combination of ParallelWork and SeriesWork, you can build series-parallel graphs in any form, which can meet the requirements in most scenarios. Counters allow us to build more complex dependencies between the tasks, such as a fully connected neural network. The following simple code can replace ParallelWork to realize parallel HTTP crawling. ~~~cpp void http_callback(WFHttpTask *task) { /* Save http page. */ ... WFCounterTask *counter = (WFCounterTask *)task->user_data; counter->count(); } std::mutex mutex; std::condition_variable cond; bool finished = false; void counter_callback(WFCounterTask *counter) { mutex.lock(); finished = true; cond.notify_one(); mutex.unlock(); } int main(int argc, char *argv[]) { WFCounterTask *counter = create_counter_task(url_count, counter_callback); WFHttpTask *task; std::string url[url_count]; /* init urls */ ... for (int i = 0; i < url_count; i++) { task = create_http_task(url[i], http_callback); task->user_data = counter; task->start(); } counter->start(); std::unique_lock lock(mutex); while (!finished) cond.wait(lock); lock.unlock(); return 0; } ~~~ The above code creates a counter with the target value as url\_count, and calls the count once after each HTTP task is completed. Note that the number of times **count()** is called on an anonymous counter cannot exceed its target value. Otherwise the counter may have been destroyed after the callback, and the program behavior is undefined. The call of **counter->start()** can be placed before the for loop. After a counter is created, you can call its count interface, no matter whether the counter has been started or not. 
You can also use **counter->WFCounterTask::count()** to call the count interface of an anonymous counter; this can be used in performance-sensitive applications. # Using a server together with other asynchronous engines In some cases, our server may need to call asynchronous clients in other frameworks and wait for the results. A simple method is that we wait synchronously in the process and then are woken up through condition variables. Its disadvantage is that we occupy a processing thread and turn asynchronous clients in other frameworks into synchronous clients. But with the counter method, we can wait without occupying threads. The method is very simple: ~~~cpp void some_callback(void *context) { protocol::HttpResponse *resp = get_resp_from_context(context); WFCounterTask *counter = get_counter_from_context(context); /* write data to resp. */ ... counter->count(); } void process(WFHttpTask *task) { WFCounterTask *counter = WFTaskFactory::create_counter_task(1, nullptr); SomeOtherAsyncClient client(some_callback, context); *series_of(task) << counter; } ~~~ Here, we can consider the series of a server task as a coroutine, and the counter whose target value is 1 can be considered as a condition variable. # Named counters When the count operation is executed on the anonymous counter, the counter object pointer is directly accessed. This inevitably requires that the number of calls to count should not exceed the target value during operation. But imagine an application scenario where we start four tasks at the same time, and as long as any three tasks are completed, the workflow can continue. We can use a counter with a target value of 3, and count once after each task is completed. As long as three tasks are completed, the callback of the counter will be executed. But the problem is that when the fourth task is finished and **counter->count()** is called again, the counter is already a wild pointer and the program crashes. 
In this case, we can use named counters to solve this problem. By naming the counter and counting by name, we can have the following implementation: ~~~cpp void counter_callback(WFCounterTask *counter) { WFRedisTask *next = WFTaskFactory::create_redis_task(...); series_of(counter)->push_back(next); } int main(void) { WFHttpTask *tasks[4]; WFCounterTask *counter; counter = WFTaskFactory::create_counter_task("c1", 3, counter_callback); counter->start(); for (int i = 0; i < 4; i++) { tasks[i] = WFTaskFactory::create_http_task(..., [](WFHttpTask *task){ WFTaskFactory::count_by_name("c1"); }); tasks[i]->start(); } ... } ~~~ In this example, four concurrent HTTP tasks are started, and as soon as any three of them are completed, a Redis task is started immediately. In the practical application, you may need to add the code of data transmission. In the example, a counter named "c1" is created, and the HTTP callback calls **WFTaskFactory::count\_by\_name()** to increase the count. ~~~cpp class WFTaskFactory { ... static int count_by_name(const std::string& counter_name); static int count_by_name(const std::string& counter_name, unsigned int n); ... }; ~~~ You can pass an integer n to **WFTaskFactory::count\_by\_name**, indicating the count value to be increased in this operation. Obviously: **count\_by\_name("c1")** is equivalent to **count\_by\_name("c1", 1)**. If the "c1" counter does not exist (not created or already completed), the operation on "c1" will have no effect, so the wild pointer problem in an anonymous counter will not happen here. The **count\_by\_name()** function returns the number of counters that were woken up by the operation. When **n** is greater than 1, more than one counter may reach its target value. # Definition of the detailed behaviors of named counters When you call **WFTaskFactory::count\_by\_name(name, n)**: * if the name does not exist (not created or already completed), there is no action. 
* if there is only one counter with that name: * if the remaining value of the counter is less than or equal to n, the counting is completed, the callback is called, and the counter is destroyed. end. * if the remaining value of the counter is greater than n, the count value is increased by n. end. * if there are multiple counters with that name: * according to the order of creation, take the first counter and assume that its remaining value is m: * if m is greater than n, the count value is increased by n. end (the remaining value is m-n). * if m is less than or equal to n, the counting is completed, the callback is called, and the counter is destroyed. set n = n-m. * If n is 0, the procedure ends. * If n is greater than 0, take out the next counter with the same name and repeat the whole operation. Although the description is very complicated, it can be summed up in one sentence. Access all counters with that name according to the order of creation one by one until n is 0. In other words, one **count\_by\_name(name, n)** may wake up multiple counters. The counters can be used to realize very complex business logic if you can use them well. In our framework, counters are often used to implement asynchronous locks or to build channels between tasks. It is more like a control task in form. workflow-0.11.8/docs/en/about-dns.md000066400000000000000000000171341476003635400172410ustar00rootroot00000000000000# About DNS When using a domain name to request the network, you first need to obtain the server address through domain name resolution, and then use the network address to make subsequent requests. Workflow has implemented a complete domain name resolution and caching system. Generally speaking, users can initiate network tasks smoothly without knowing the internal mechanism. 
## DNS related configuration Global configuration in Workflow includes ~~~cpp struct WFGlobalSettings { struct EndpointParams endpoint_params; struct EndpointParams dns_server_params; unsigned int dns_ttl_default; unsigned int dns_ttl_min; int dns_threads; int poller_threads; int handler_threads; int compute_threads; int fio_max_events; const char *resolv_conf_path; const char *hosts_path; }; ~~~ Among them, the configuration items related to domain name resolution include * dns_server_params * address_family: This item will be explained later * max_connections: The maximum number of concurrent requests sent to the DNS server, the default is 200 * connect_timeout/response_timeout/ssl_connect_timeout: refer to [timeout](about-timeout.md) for related instructions * dns_threads: When using synchronous mode to implement domain name resolution, the resolution operation will be executed in an independent thread pool. This item specifies the number of threads in the thread pool. The default is 4. * dns_ttl_default: The result of successful domain name resolution will be placed in the domain name cache. This item specifies its survival time in seconds. The default value is 1 hour. When the resolution result expires, it will be re-parsed to obtain the latest content. * dns_ttl_min: When communication fails, the cached result may have expired. This item specifies a shorter survival time. When communication fails, the cache is updated at a more frequent rate. The unit is seconds. The default value is 1 minute. * resolv_conf_path: This file saves the configuration related to accessing DNS. It is usually located in `/etc/resolv.conf` on common Linux distributions. If this item is configured as `NULL`, it means using multi-threaded synchronous resolution mode. * hosts_path: This file is a local domain name lookup table. If the resolved domain name hits this table, it will not initiate a request to DNS. It is usually located in `/etc/hosts` on common Linux distributions. 
If this item is configured as `NULL`, the lookup table is not used. ### resolv.conf extensions Workflow has extended the `resolv.conf` configuration file. Users can modify the configuration to support the `DNS over TLS(DoT)`. **Note**: directly modifying `/etc/resolv.conf` will affect other processes. You can make a copy of the file for modification, and modify the `resolv_conf_path` configuration of Workflow to the path of the new file. For example, a `nameserver` using the `dnss` protocol will connect via SSL: ~~~bash nameserver dnss://8.8.8.8/ nameserver dnss://[2001:4860:4860::8888]/ ~~~ ### Address Family In some network environments, although the machine supports IPv6, it cannot communicate with the outside because it has not been assigned a public IPv6 address (for example, the local IPv6 address starts with `fe80`). At this time, you can set `endpoint_params.address_family` to `AF_INET` to force only IPv4 addresses to be resolved during domain name resolution. Similarly, the `resolv.conf` file may specify both the IPv4 address and the IPv6 address of the `nameserver`. In this case, you can set `dns_server_params.address_family` to `AF_INET` or `AF_INET6` to force the use of only IPv4 or IPv6 addresses to access DNS. ### Use Upstream configuration The global configuration takes effect for each domain name by default. If you need to specify different configurations for certain domain names, you can use the [Upstream](./about-upstream.md#Address attribute) function. Using Upstream, you can individually specify the `dns_ttl_default` and `dns_ttl_min` configuration items, and individually specify the IP address family used by the domain name through `endpoint_params.address_family`. ## Domain name resolution and caching strategy Network tasks usually require domain name resolution to obtain the IP address that needs to be accessed. The relevant strategies for domain name resolution in Workflow are as follows: 1. 
Check whether the domain name cache has the IP address corresponding to the domain name. If there is a cache and it has not expired, use this set of IP addresses. 2. Check whether the domain name is an IPv4, IPv6 address or `Unix Domain Socket`. If so, use the address directly without initiating domain name resolution. 3. Check whether the `hosts_path` file contains the IP address corresponding to the domain name. If so, use the address directly. 4. Obtain an asynchronous lock to ensure that a resolution request for the same domain name is only initiated once at the same time, and initiate a resolution request to DNS 5. After successful parsing, the parsing result will be saved to the domain name cache of the current process for next use, and the asynchronous lock will be released. 6. After the parsing fails, the asynchronous lock will be released and the failure reason will be notified to all tasks waiting on the same asynchronous lock. New tasks initiated after the notification is completed will request DNS again. Many scenarios that require a large number of network requests will be equipped with a domain name caching component. If a resolution request is sent to the DNS every time a network task is initiated, the DNS will inevitably be overwhelmed. Workflow sets the cache survival time (dns_ttl_default and dns_ttl_min) to ensure that the cache will expire after a reasonable period of time and the domain name resolution results can be updated in a timely manner. When a cache item of a domain name expires, the first task found to be expired will extend its survival time by 5 seconds and initiate a resolution request to DNS. Requests on the same domain name within 5 seconds will directly use the cached DNS resolution results without waiting. The asynchronous lock mechanism can ensure that the resolution request for the **same domain name** is only initiated once at the same time. 
Without lock protection, if a large number of network tasks are initiated for the same domain name in a short period of time, each task will be unable to be retrieved from the cache. Too many resolution request to DNS will place a large and unnecessary burden on DNS. The same domain name here represents the `(host, port, family)` triplet. If a domain name is required to only use IPv4 and IPv6 through Upstream, they will be protected by different asynchronous locks, and it is possible to request DNS at the same time. ### Asynchronous domain name resolution Workflow implements a complete DNS task. If the `resolv_conf_path` configuration item is specified, an asynchronous request will be used when initiating domain name resolution to DNS. Under Unix-like systems, Workflow uses `/etc/resolv.conf` as the value of this configuration by default. Asynchronous domain name resolution does not block any threads or monopolize the thread pool, and can complete the task of domain name resolution more efficiently. ### Synchronous domain name resolution If `resolv_conf_path` is specified as `NULL`, synchronous domain name resolution will be achieved by calling the `getaddrinfo` function. This method will use an independent thread pool, and the number of threads is configured through the `dns_threads` parameter. If a large number of domain name resolution requests need to be initiated in a short period of time, the synchronization method will cause a large delay. workflow-0.11.8/docs/en/about-error.md000066400000000000000000000103561476003635400176050ustar00rootroot00000000000000# About error handling Error handling is an important and complex problem in any software system. Within our framework, error handling is ubiquitous and extremely cumbersome. In the interfaces we exposed to users, we try to make things as simple as possible, but users still inevitably need to know some error messages. ### Disabling C++ exceptions C++ exceptions are not used in our framework. 
When you compile your own code, it is best to add **-fno-exceptions** flag to reduce the code size. According to the common practice in the industry, we ignore the possibility of the failure of **new** operation, and avoid using new to allocate large blocks of memory internally. And there are error checks in memory allocation in C style. ### About factory functions From the previous examples, you can see that all task and series are generated from two factory classes, WFTaskFactory or Workflow. These factory classes, as well as more factory class interfaces that we may encounter in the future, ensure success. In other words, they never return NULL. And you do not need to check the return value. To achieve this goal, even when the URL is illegal, the factory still generates the task normally. And you will get the error in the callback of the task. ### States and error codes of a task In the previous examples, you often see such codes in the callback: ~~~cpp void callback(WFXxxTask *task) { int state = task->get_state(); int error = task->get_error(); ... } ~~~ in which, the state indicates the end state of a task. [WFTask.h](/src/factory/WFTask.h) contains all possible states: ~~~cpp enum { WFT_STATE_UNDEFINED = -1, WFT_STATE_SUCCESS = CS_STATE_SUCCESS, WFT_STATE_TOREPLY = CS_STATE_TOREPLY, /* for server task only */ WFT_STATE_NOREPLY = CS_STATE_TOREPLY + 1, /* for server task only */ WFT_STATE_SYS_ERROR = CS_STATE_ERROR, WFT_STATE_SSL_ERROR = 65, WFT_STATE_DNS_ERROR = 66, /* for client task only */ WFT_STATE_TASK_ERROR = 67, WFT_STATE_ABORTED = CS_STATE_STOPPED /* main process terminated */ }; ~~~ ##### Please note the following states: * SUCCESS: the task is successfully completed. The client receives the complete reply, or the server writes the reply completely into the send buffer (but there is no guarantee that the peer will receive it). * SYS\_ERROR: system error. In this case, use **task->get\_error()** to get the system error code **errno**. 
* When **get\_error()** returns ETIMEDOUT, you can call **task->get\_timeout\_reason()** to get the timeout reason. * DNS\_ERROR: DNS resolution error. Use **get\_error()** to get the return code of **getaddrinfo()**. For details about DNS, please see [about-dns.md](/docs/en/about-dns.md). * The server task never has a DNS\_ERROR. * SSL\_ERROR: SSL error. Use **get\_error()** to get the return value of **SSL\_get\_error()**. * Currently SSL error information is not complete, and you cannot get the value of **ERR\_get\_error()**. Therefore, basically there are three possible return values of **get\_error()**: * SSL\_ERROR\_ZERO\_RETURN, SSL\_ERROR\_X509\_LOOKUP, SSL\_ERROR\_SSL. * We will consider adding more detailed SSL error information in future versions. * TASK\_ERROR: task errors. Common errors include illegal URL, login failure, etc. [WFTaskError.h](/src/factory/WFTaskError.h) lists the return values of **get\_error()**. ##### You do not need to pay attention to the following states: * UNDEFINED: Client tasks that have just been created and have not yet been run are in UNDEFINED state. * TOREPLY: Server tasks that have not sent replies or called **task->noreply()** are in TOREPLY state. * NOREPLY: Server tasks that have called **task->noreply()** are always in NOREPLY state. The callbacks of these tasks are also in NOREPLY state. And the connection will be closed. ### Other error handling requirements In addition to the error handling of the task itself, you also need to check the errors of the message interfaces of various protocols. Generally, these interfaces indicate errors by returning false, and show the error reasons in the errno. In addition, you may encounter more complicated error messages when you use some complex operations. You will learn about them in the detailed documents. 
workflow-0.11.8/docs/en/about-exit.md000066400000000000000000000141321476003635400174210ustar00rootroot00000000000000# About exit As most of our calls are non-blocking, we need some mechanisms to prevent the main function from exiting early in the previous examples. For example, in the wget example, we wait for the user's Ctrl-C, or in the parallel\_wget example, we wake up the main thread after all crawling tasks are finished. In several server examples, the **stop()** operation is blocking, which can ensure the normal end of all server tasks and the safe exit of the main thread. # Principles on the safe exit Generally, as long as you write the program normally and follow the methods in the examples, there is little doubt about exit. However, it is still necessary to define the conditions for normal program exit. * You can't call **exit()** of the system in any callback functions such as the callback or the process, otherwise the behavior is undefined. * The condition that a main thread can safely end (call **exit()** or return in the main function) is that all tasks have been run to callbacks and no new tasks are started. * All our examples are consistent with this assumption, waking up the main function in the callback. This is safe, and there is no need to worry about the situation where the callback is not finished when the main function returns. * ParallelWork is a kind of task, which also needs to run to its callback. * This rule can be violated under certain circumstances. We will talk about it in the following section. * All servers must stop, otherwise the behavior is undefined. Because all users know how to call the stop operation, generally a server program will not have any exit problems. * Server's stop() method will block until all server tasks' series end. But if you start a task directly in the process function, you have to take care of the end of this task. # Why do I need to wait for the callback of a running task? Can the program be ended early? 
First, explain why you need to wait till the callback of tasks before ending the program. In most cases, the tasks generated through the task factory are composite tasks. For example, an http client task may need to resolve the dns, and then initiate the http crawl. And if a 302 redirect is encountered, dns resolving may be needed again. If the task fails, retrying may be involved. In other words, any asynchronous tasks may contain multiple asynchronous processes, but it is completely insensitive to users. But between each internal asynchronous process, it does not check whether the program has exited. If the user clearly knows that a task is an atomic task, for example, an http task created with an IP address (or a dns cache can definitely be hit), and there is no redirection or retry. Then, this task can be interrupted by the program's exit and come to the callback early, and the state of the task in the callback is WFT_STATE_ABORTED. For example, the following program is always safe: ~~~cpp void callback(WFHttpTask *task) { // most probably print 2,WFT_STATE_ABORTED。 printf("state = %d\n", task->get_state()); } int main() { WFHttpTask *task = WFTaskFactory::create_http_task("https://127.0.0.1/", 0, 0, callback); task->start(); // end the main process directly return 1; } ~~~ If the dns cache hits, it is safe. Because there is no need to initiate a dns asynchronous task internally. E.g: ~~~cpp WFFacilities::WaitGroup wg(1) void callback_normal(WFHttpTask *task) { wg.done(); } void callback_abort(WFHttpTask *task) { // most probably print 2,WFT_STATE_ABORTED。 printf("state = %d\n", task->get_state()); } int main() { WFHttpTask *task = WFTaskFactory::create_http_task("https://www.sogou.com/", 3, 2, callback_normal); task->start(); // wait for the end of the first task wg.wait(); // Access wwww.sogou.com again. Hit the dns cache definitely. 
WFHttpTask *task = WFTaskFactory::create_http_task("https://www.sogou.com/", 0, 0, callback_abort); task->start(); // end the main process directly return 1; } ~~~ Therefore, for a network task, as long as it can be determined to be an atomic task, it can be interrupted by the end of the program. This principle can be extended to any type of task. For example, the timer task is an atomic task, and the following program is also safe: ~~~cpp void callback(WFTimerTask *task) { // definitely print 2,WFT_STATE_ABORTED。 printf("state = %d\n", task->get_state()); } int main() { WFTimerTask *task = WFTaskFactory::create_timer_task(1000000, callback); task->start(); // end the main process directly return 1; } ~~~ In the documentation [About Timer](https://github.com/sogou/workflow/blob/master/docs/en/about-timer.md), we describe timers in detail. In addition, you can also end the program before the callback of single-threaded computing tasks and file IO tasks. For a computing task that is already running, the program will wait for the task to end, and the callback is finally called in the SUCCESS state. If it has not begun running, it will be canceled and you will get an ABORTED state in the callback. As long as the file IO task has been started, the program will always wait for the IO to complete. Therefore, it is always safe to exit the program directly. # About memory leakage of OpenSSL 1.1 in exiting We found that some OpenSSL 1.1 versions have the problem of incomplete memory release when exiting. The memory leak can be seen with the Valgrind memcheck tool. This problem only happens when you use SSL, such as crawling HTTPS web pages, and usually you can ignore this leak. If it must be solved, you can use the following method: ~~~cpp #include int main() { #if OPENSSL_VERSION_NUMBER >= 0x10100000L OPENSSL_init_ssl(0, NULL); #endif ... } ~~~ In other words, before using our library, you should initialize OpenSSL. You can also configure OpenSSL parameters at the same time if necessary. 
Please note that this function is only available in OpenSSL version 1.1 or above, so you need to check the OpenSSL version before calling it. This memory leak is related to the memory release mechanism of OpenSSL 1.1. The solution provided by us can solve this problem (but we still recommend that you ignore it). workflow-0.11.8/docs/en/about-go-task.md000066400000000000000000000062461476003635400200240ustar00rootroot00000000000000# About go task We provide a simpler way to use computing tasks, which is inspired by Golang, and we name it 'go task'. When using go task, neither input nor output type has to be defined. All data are passed through the function's arguments. # Creating a go task ~~~cpp class WFTaskFactory { ... public: template static WFGoTask *create_go_task(const std::string& queue_name, FUNC&& func, ARGS&&... args); }; ~~~ # Example We want to run an 'add' function asynchronously: void add(int a, int b, int& res); Still, we want the result printed after the 'add' function is finished. We may create a go task: ~~~cpp #include #include #include "workflow/WFTaskFactory.h" #include "workflow/WFFacilities.h" void add(int a, int b, int& res) { res = a + b; } int main(void) { WFFacilities::WaitGroup wait_group(1); int a = 1; int b = 1; int res; WFGoTask *task = WFTaskFactory::create_go_task("test", add, a, b, std::ref(res)); task->set_callback([&](WFGoTask *task) { printf("%d + %d = %d\n", a, b, res); wait_group.done(); }); task->start(); wait_group.wait(); return 0; } ~~~ The above example runs an add function asynchronously, prints the result and exits normally. The creating and running of go task have little difference from other kinds of tasks, and the user_data field is also available. Note that when creating a go task, we do not pass a callback function. But you may call set_callback later like other kinds of tasks. If an argument of the go task's function is a reference, you should use `std::ref` when passing it to the task, otherwise it will be passed as a value. 
# Go task with running time limit You may create a go task with running time limit by calling WFTaskFactory::create_timedgo_task(): ~~~cpp class WFTaskFactory { /* Create 'Go' task with running time limit in seconds plus nanoseconds. * If time exceeded, state WFT_STATE_SYS_ERROR and error ETIMEDOUT will be got in callback. */ template static WFGoTask *create_timedgo_task(time_t seconds, long nanoseconds, const std::string& queue_name, FUNC&& func, ARGS&&... args); }; ~~~ Compared with creating a normal go task, the ``create_timedgo_task`` function takes two more parameters, seconds and nanoseconds. If the running time of ``func`` reaches the seconds+nanoseconds time limit, the task's callback is called directly, with the state WFT_STATE_SYS_ERROR and the error ETIMEDOUT. Note that the framework cannot interrupt the user's ongoing task. ``func`` will still continue to execute to the end, but the callback will not be called again. In addition, the value range of nanoseconds is [0, 1 billion). # Use the whole library as a thread pool You may use go task only. In this way the workflow library becomes a thread pool, and the default number of threads is equal to the number of CPUs of the host. But this thread pool has some special features. Every thread task is associated with a queue name that will indicate scheduling, and you may set up the dependency of all tasks too. workflow-0.11.8/docs/en/about-module.md000066400000000000000000000066471476003635400177510ustar00rootroot00000000000000# About Module Task Our **series** has tasks as elements. But in many cases, users need module-level encapsulation, such as several tasks to complete a specific function. With the original method, you have to let the callback of the last task connect to the next task, or fill in the response of the server task. Therefore, we introduced WFModuleTask, which is convenient for users to encapsulate modules and reduce the coupling of tasks between different functional modules. 
# Create a Module Task We define a **module** as a kind of task, WFModuleTask. A module contains a sub_series for running the tasks within the module. A task does not need to care whether it runs inside a module, because the sub_series inside the module is no different from a normal series. In [WFTaskFactory.h](/src/factory/WFTaskFactory.h), the creation interface of module task: ~~~cpp using module_callback_t = std::function<void (const WFModuleTask *)>; class WFTaskFactory { static WFModuleTask *create_module_task(SubTask *first, module_callback_t callback); }; ~~~ The first argument of create_module_task() is the first task of the module, similar to creating a series. The module task’s callback receives a **const** pointer argument in order to prevent users from pushing more tasks to the module in the callback. # WFModuleTask Interfaces Because we define modules as this kind of task, we can use modules like any other task. But modules do not have **state** and **error** fields. In [WFTask.h](/src/factory/WFTask.h), we define the class of WFModuleTask: ~~~cpp class WFModuleTask : public ParallelTask, protected SeriesWork { public: void start() { ... } void dismiss() { ... } public: SeriesWork *sub_series() { return this; } const SeriesWork *sub_series() const { return this; } public: void *user_data; }; ~~~ The **sub_series** interface returns the series of tasks running in the module. A module is essentially a sub-flow. sub_series is also an ordinary series, and users can call its set_context(), get_context(), push_back() and other functions. But we don't recommend setting a callback for sub_series; use the module task’s callback instead. # Example In the processing logic of an http server, we design all processing logic as a module. 
~~~cpp struct ModuleCtx { std::string body; }; void http_callback(WFHttpTask *http_task) { SeriesWork *series = series_of(http_task); // This series is module’s sub_series struct ModuleCtx *ctx = (struct ModuleCtx *)series->get_context(); const void *body; size_t size; if (http_task->get_resp()->get_parsed_body(&body, &size)) { ctx->body.assign((const char *)body, size); } ParallelWork *pwork = Workflow::create_parallel_work(…);// Do some other things series->push_back(pwork); } void process(WFHttpTask *server_task) { WFHttpTask *http_task = WFTaskFactory::create_http_task(…, http_callback); WFModuleTask *module = WFTaskFactory::create_module_task(http_task, [server_task](const WFModuleTask *mod) { struct ModuleCtx *ctx = (ModuleCtx *)mod->sub_series()->get_context(); server_task->get_resp()->append_output_body(ctx->body); delete ctx; }); module->sub_series()->set_context(new ModuleCtx); series_of(server_task)->push_back(module); } ~~~ Through this method, the tasks in the module only need to operate the series context, and finally the **resp** is filled in by the callback of the module. Task coupling is greatly reduced. workflow-0.11.8/docs/en/about-resource-pool.md000066400000000000000000000070401476003635400212460ustar00rootroot00000000000000# Conditional task and resource pool When we use workflow to write asynchronous programs, we often encounter such scenarios: * A task needs to obtain a resource from a certain pool before running, and put it back to the pool after it finishes. * We may need to limit the max concurrency of accessing one or more communication targets, but we don't want to occupy a thread when waiting. * We have many tasks that arrive randomly, in different series. But these tasks must be run serially. All these needs can be solved with the resource pool module. Our [WFDnsResolver](https://github.com/sogou/workflow/blob/master/src/nameservice/WFDnsResolver.cc) uses this method to control the concurrency of querying the dns server. 
# Interfaces of resource pool In [WFResourcePool.h](https://github.com/sogou/workflow/blob/master/src/factory/WFResourcePool.h) we define the interfaces of resource pool: ~~~cpp class WFResourcePool { public: WFConditional *get(SubTask *task, void **resbuf); WFConditional *get(SubTask *task); void post(void *res); ... protected: virtual void *pop() { return this->data.res[this->data.index++]; } virtual void push(void *res) { this->data.res[--this->data.index] = res; } ... public: WFResourcePool(void *const *res, size_t n); WFResourcePool(size_t n); ... }; ~~~ ### Constructors The first constructor accepts a resource array of length n. Each element of the array is a **void \*** representing a resource. The whole array will be copied by the constructor. If all the initial resources are **nullptr**, you may use the second constructor which has only one argument n, representing the number of resources. You may take a look at the implementation code: ~~~cpp void WFResourcePool::create(size_t n) { this->data.res = new void *[n]; this->data.value = n; ... } WFResourcePool::WFResourcePool(void *const *res, size_t n) { this->create(n); memcpy(this->data.res, res, n * sizeof (void *)); } WFResourcePool::WFResourcePool(size_t n) { this->create(n); memset(this->data.res, 0, n * sizeof (void *)); } ~~~ ### Application interfaces Users use the **get()** method of the resource pool to wrap a task. **get()** returns a conditional, which is also a task. The conditional will run the task it wraps when it obtains a resource from the pool. **get()** may accept a second argument **void \*\* resbuf**, which is the buffer that will store the resource obtained. After the **get()** operation, users can use the returned conditional as a substitute for the original task. It can be started or put to any series just like an ordinary task. After the user task is finished, **post()** needs to be called to return a resource to the pool. Typically, **post()** is called in the user task's callback. 
### Derivation The using of resource pool is FILO. It means the last released resource will be the next one to be obtained. You may subclass WFResourcePool to implement a FIFO pool. ### Example We have a URL list to be crawled. But we limit the max concurreny of crawling task to be **max_p**. We may use ParallelWork to implement this function of course. But with resource pool, everything is much simpler: ~~~cpp int fetch_with_max(std::vector& url_list, size_t max_p) { WFResourcePool pool(max_p); for (std::string& url : url_list) { WFHttpTask *task = WFTaskFactory::create_http_task(url, [&pool](WFHttpTask *task) { pool.post(nullptr); }); WFConditional *cond = pool.get(task); cond->start(); } // wait_here... } ~~~ workflow-0.11.8/docs/en/about-service-governance.md000066400000000000000000000245211476003635400222400ustar00rootroot00000000000000# About service governance We have a complete mechanism to manage the services we depend on. This mechanism includes the following functions: * User level DNS. * Selection of service addresses. * Including a variety of selection mechanisms, such as random weight, consistent hash, manual selection methods, etc. * Service circuit breaker and recovery. * Load balancing. * Configuring independent parameters for a single service. * Main/backup relations for a service, etc. All these functions depend on our upstream subsystem. By making good use of this system, we can easily implement more complex service mesh functions. # upstream name upstream name is equivalent to the domain name inside the program. However, compared with the general domain name, upstream has more functions, including: * Generally, a domain name can only point to a set of IP addresses; an upstream name can point to a set of IP addresses or domain names. * The objects (domain names or IPs) pointed by the upstream may include port information. 
* upstream has powerful functions for managing and selecting targets, and each target can contain a large number of attributes. * upstream update is real-time and completely thread-safe, while the DNS of domain names cannot be updated in real time. In practice, if you don't need to access the external network, the domain names and DNS can be completely replaced by upstream. # Creating and deleting upstream [UpstreamManager.h](/src/manager/UpstreamManager.h) contains several interfaces for creating upstream: ~~~cpp using upstream_route_t = std::function<unsigned int (const char *, const char *, const char *)>; class UpstreamManager { public: static int upstream_create_consistent_hash(const std::string& name, upstream_route_t consistent_hash); static int upstream_create_weighted_random(const std::string& name, bool try_another); static int upstream_create_manual(const std::string& name, upstream_route_t select, bool try_another, upstream_route_t consistent_hash); static int upstream_delete(const std::string& name); ... }; ~~~ The three functions create three types of upstream: consistent hash, weighted random and manual selection. The parameter **name** means upstream name, which is used in the same way as a domain name after creation. **consistent\_hash** and **select** parameters are both **std::function** of **upstream\_route\_t**, which are used to specify the routing method. And try\_another indicates whether to continue trying to find an available target if the selected target is unavailable (blown). consistent\_hash mode does not have this attribute. An upstream\_route\_t function receives three parameters: path, query and fragment in a URL. For example, if the URL is http://abc.com/home/index.html?a=1#bottom, the three parameters are "/home/index.html", "a=1" and "bottom" respectively. Based on these three parts, the system can select the target server or perform consistent hashing. 
Please note that you can pass nullptr to all consistent\_hash parameters in the above interfaces, and the framework will use the default consistent hash algorithm. # Example 1: weight allocation We want to allocate 50% of the requests to www.sogou.com to 127.0.0.1:8000 and 127.0.0.1:8080, and make their load be 1:4. We don't need to care about the number of IP addresses behind the domain name www.sogou.com. In short, the actual domain name will receive 50% of the requests. ~~~cpp #include "workflow/UpstreamManager.h" #include "workflow/WFTaskFactory.h" int main() { UpstreamManager::upstream_create_weighted_random("www.sogou.com", false); struct AddressParams params = ADDRESS_PARAMS_DEFAULT; params.weight = 5; UpstreamManager::upstream_add_server("www.sogou.com", "www.sogou.com", &params); params.weight = 1; UpstreamManager::upstream_add_server("www.sogou.com", "127.0.0.1:8000", &params); params.weight = 4; UpstreamManager::upstream_add_server("www.sogou.com", "127.0.0.1:8080", &params); WFHttpTask *task = WFTaskFactory::create_http_task("http://www.sogou.com/index.html", ...); ... } ~~~ Please note that these functions can be called in any scenario. They are completely thread-safe and take effect instantly. In addition, because all our protocols, including user-defined protocols, have URLs, the upstream function can be applied to all protocols. # Example 2: manual selection In the same example as above, we want to allocate 127.0.0.1:8000 if the query in the request URLs is "123", port 8080 if the query is "abc", and normal domain names for other requests. 
~~~cpp #include "workflow/UpstreamManager.h" #include "workflow/WFTaskFactory.h" int my_select(const char *path, const char *query, const char *fragment) { if (strcmp(query, "123") == 0) return 1; else if (strcmp(query, "abc") == 0) return 2; else return 0; } int main() { UpstreamManager::upstream_create_manual("www.sogou.com", my_select, false, nullptr); UpstreamManager::upstream_add_server("www.sogou.com", "www.sogou.com"); UpstreamManager::upstream_add_server("www.sogou.com", "127.0.0.1:8000"); UpstreamManager::upstream_add_server("www.sogou.com", "127.0.0.1:8080"); /* This URL will route to 127.0.0.1:8080 */ WFHttpTask *task = WFTaskFactory::create_http_task("http://www.sogou.com/index.html?abc", ...); ... } ~~~ Because Redis and MySQL protocols are provided natively, it is very convenient to realize the read-write separation function of the database with this method (Note: non-transactional operation). In the above two examples, the upstream name is www.sogou.com, which is also a domain name. Of course, you can use a simpler string sogou as upstream name. Thus: ~~~cpp WFHttpTask *task = WFTaskFactory::create_http_task("http://sogou/home/1.html?abc", ...); ~~~ In a word, if the host part of the URL is a created upstream, it will be used as an upstream. # Example 3: consistent hash In this scenario, we will randomly select one machine from 10 Redis instances and communicate with it. But we must ensure that the same URL always accesses the same specific target. The method is very simple: ~~~cpp int main() { UpstreamManager::upstream_create_consistent_hash("redis.name", nullptr); UpstreamManager::upstream_add_server("redis.name", "10.135.35.53"); UpstreamManager::upstream_add_server("redis.name", "10.135.35.54"); UpstreamManager::upstream_add_server("redis.name", "10.135.35.55"); ... UpstreamManager::upstream_add_server("redis.name", "10.135.35.62"); auto *task = WFTaskFactory::create_redis_task("redis://:mypassword@redis.name/2?a=hello#111", ...); ... 
} ~~~ Our Redis task does not recognize the query part, so you can fill it out at will. 2 in the path indicates the Redis database ID. At this time, the consistent\_hash function will get three parameters: "/2", "a=hello" and "111". Because we use nullptr, the default consistent hash will be called. As we do not specify the port number for the server in upstream, it will use the port in the URL. The default port of Redis is 6379. There is no try\_another option for consistent\_hash. If the target is blown, another one will be automatically selected. The same URL will always get the same server (cache friendly). # Parameters of upstream server In Example 1, we set the weight of a server through params. But a server has far more parameters than just a weight. Its struct is defined as follows: ~~~cpp // In EndpointParams.h struct EndpointParams { size_t max_connections; int connect_timeout; int response_timeout; int ssl_connect_timeout; bool use_tls_sni; }; // In ServiceGovernance.h struct AddressParams { struct EndpointParams endpoint_params; ///< Connection config unsigned int dns_ttl_default; ///< in seconds, DNS TTL when network request success unsigned int dns_ttl_min; ///< in seconds, DNS TTL when network request fail /** * - The max_fails directive sets the number of consecutive unsuccessful attempts to communicate with the server. * - After 30s following the server failure, upstream probe the server with some alive client’s requests. * - If the probes have been successful, the server is marked as an alive one. * - If max_fails is set to 1, it means server would out of upstream selection in 30 seconds when failed only once */ unsigned int max_fails; ///< [1, INT32_MAX] max_fails = 0 means max_fails = 1 unsigned short weight; ///< [1, 65535] weight = 0 means weight = 1. only for main int server_type; ///< 0 for main and 1 for backup int group_id; ///< -1 means no group. 
Backup without group will backup for any main }; ~~~ Most of the parameters are self-explanatory. Among these parameters, endpoint\_params, dns and other parameters will override the global configuration. For example, if the global maximum number of connections to each target IP is 200, but you want to set a maximum of 1000 connections for 10.135.35.53, please follow the instructions below: ~~~cpp UpstreamManager::upstream_create_weighted_random("10.135.35.53", false); struct AddressParams params = ADDRESS_PARAMS_DEFAULT; params.endpoint_params.max_connections = 1000; UpstreamManager::upstream_add_server("10.135.35.53", "10.135.35.53", &params); ~~~ The max\_fails parameter indicates the maximum number of failures. If the selected target continuously fails, and the number of failures reaches max\_fails, it will enter the fusing state. If the try\_another attribute of upstream is false, the task will fail. In the callback of the task, get\_state()=WFT\_STATE\_TASK\_ERROR, get\_error()=WFT\_ERR\_UPSTREAM\_UNAVAILABLE. If try\_another is true and all servers are blown, you will get the same error. The fusing time is 30 seconds. Server\_type and group\_id are used for main/backup features. All upstream must have a server whose type is 0, representing main, otherwise the upstream is unavailable. Backup servers (server_type 1) will be used when the main servers of the same group\_id are blown. For more information on the features of upstream, please see [about-upstream.md](/docs/en/about-upstream.md). workflow-0.11.8/docs/en/about-timeout.md000066400000000000000000000231161476003635400201400ustar00rootroot00000000000000# About timeout In order to make all communication tasks run as accurately as expected by users, the framework provides a large number of timeout configuration functions and ensures the accuracy of these timeouts. 
Some of these timeout configurations are global, such as connection timeout, but you may configure your own connection timeout for a particular domain name through the upstream. Some timeouts are task-level, such as sending a message completely, because users need to dynamically configure this value according to the message size. Of course, a server may have its own overall timeout configuration. In a word, timeout is a complicated matter, and the framework will do it accurately. All timeouts are in **poll** style: each is an **int** in milliseconds, and -1 means infinite. In addition, as said in the project introduction, you can ignore all the configurations, and adjust them when you meet the actual requirements. ### Timeout configuration for basic communication [EndpointParams.h](/src/manager/EndpointParams.h) contains the following items: ~~~cpp struct EndpointParams { size_t max_connections; int connect_timeout; int response_timeout; int ssl_connect_timeout; }; static constexpr struct EndpointParams ENDPOINT_PARAMS_DEFAULT = { .max_connections = 200, .connect_timeout = 10 * 1000, .response_timeout = 10 * 1000, .ssl_connect_timeout = 10 * 1000, }; ~~~ in which there are three DNS-related configuration items. Please ignore them right now. Items related to timeout: * connect\_timeout: timeout for establishing a connection with the target. The default value is 10 seconds. * response\_timeout: timeout for waiting for the target response; the default value is 10 seconds. It is the timeout for sending a block of data to the target or reading a block of data from the target. * ssl\_connect\_timeout: timeout for completing SSL handshakes with the target. The default value is 10 seconds. This struct is the most basic configuration for the communication connection, and almost all subsequent communication configurations contain this struct. ### Global timeout configuration You can see the global settings in [WFGlobal.h](/src/manager/WFGlobal.h). 
~~~cpp struct WFGlobalSettings { EndpointParams endpoint_params; unsigned int dns_ttl_default; unsigned int dns_ttl_min; int dns_threads; int poller_threads; int handler_threads; int compute_threads; }; static constexpr struct WFGlobalSettings GLOBAL_SETTINGS_DEFAULT = { .endpoint_params = ENDPOINT_PARAMS_DEFAULT, .dns_ttl_default = 12 * 3600, /* in seconds */ .dns_ttl_min = 180, /* reacquire when communication error */ .dns_threads = 8, .poller_threads = 2, .handler_threads = 20, .compute_threads = -1 }; //compute_threads<=0 means auto-set by system cpu number ~~~ in which there is one timeout related configuration item: EndpointParams endpoint\_params You can perform operations like the following to change the global configuration before calling any of our factory functions: ~~~cpp int main() { struct WFGlobalSettings settings = GLOBAL_SETTINGS_DEFAULT; settings.endpoint_params.connect_timeout = 2 * 1000; settings.endpoint_params.response_timeout = -1; WORKFLOW_library_init(&settings); } ~~~ The above example changes the connection timeout to 2 seconds, and the server response timeout is infinite. In this configuration, the timeout for receiving complete messages must be configured in each task, otherwise it may fall into infinite waiting. The global configuration can be overridden by the configuration for an individual address in the upstream feature. For example, you can specify a connection timeout for a specific domain name. In Upstream, each AddressParams also has the EndpointParams endpoint\_params item, and you can configure it in the same way as you configure the Global item. For the detailed structures, please see [upstream documents.](/docs/en/tutorial-10-upstream.md#Address) ### Configuring server timeout The [http\_proxy](/docs/en/tutorial-05-http_proxy.md) example demonstrates the server startup configuration. 
In which the timeout-related configuration items include: * peer\_response\_timeout: its definition is the same as the global peer\_response\_timeout, which indicates the response timeout of the remote client, and the default value is 10 seconds. * receive\_timeout: timeout for receiving a complete request. The default value is -1. * keep\_alive\_timeout: timeout for keeping a connection. The default value is 1 minute. For a Redis server, the default value is 5 minutes. * ssl\_accept\_timeout: timeout for completing SSL handshakes. The default value is 10 seconds. Under this default configuration, the client can send one byte every 9 seconds, so that the server can always receive it and no timeout occurs. Therefore, if the service is used for public network, you need to configure receive\_timeout. ### Configuring task-level timeout Task-level timeout configuration is accomplished through calling several interfaces in a network task: ~~~cpp template class WFNetworkTask : public CommRequest { ... public: /* All in milliseconds. timeout == -1 for unlimited. */ void set_send_timeout(int timeout) { this->send_timeo = timeout; } void set_receive_timeout(int timeout) { this->receive_timeo = timeout; } void set_keep_alive(int timeout) { this->keep_alive_timeo = timeout; } void set_watch_timeout(int timeout) { this->watch_timeo = timeout; } ... } ~~~ In the above code, **set\_send\_timeout()** sets the timeout for sending a complete message, and the default value is -1. **set\_receive\_timeout()** is only valid for the client task, and it indicates the timeout for receiving a complete server reply. The default value is -1. * The receive\_timeout of a server task is in the server startup configuration. All server tasks handled by users have successfully received complete requests. **set\_keep\_alive()** interface sets the timeout for keeping a connection. Generally, the framework can handle the connection maintenance well, and you do not need to call it. 
When an HTTP protocol is used, if a client or a server wants to use short connection, you can add an HTTP header to support it. Please do not modify it with this interface if you have other options. If a Redis client wants to close the connection after a request, you need to use this interface. Obviously, **set\_keep\_alive()** is invalid in the callback (the connection has been reused). **set\_watch\_timeout()** is specific for client task only. It indicates the maximum time of waiting for the first response package. This may prevent the client task from being timed out by the limit of **response\_timeout** and **receive\_timeout**. The framework will calculate **receive\_timeout** after receiving the first package if **watch\_timeout** is set. ### Timeout for synchronous task waiting There is a very special timeout configuration, and it is the only global synchronous waiting timeout. It is not recommended, but you can get good results with it in some application scenarios. In the current framework, the target server has a connection limit (you can set it in both global and upstream configurations). If the number of connections has reached the upper limit, the client task fails and returns an error by default. In the callback, **task->get\_state()** gets WFT\_STATE\_SYS\_ERROR, and **task->get\_error()** gets EAGAIN. If the task is configured with retry, a retry will be automatically initiated. Here, it is allowed to configure a synchronous waiting timeout through the **task->set\_wait\_timeout()** interface. If a connection is released during this time period, the task can occupy this connection. If you set wait\_timeout and the task does not get the connection before the timeout, the callback will get WFT\_STATE\_SYS\_ERROR status and ETIMEDOUT error. ~~~cpp class CommRequest : public SubTask, public CommSession { public: ... 
void set_wait_timeout(int wait_timeout) { this->wait_timeout = wait_timeout; } } ~~~ ### Viewing the reasons for timeout Communication tasks contain a **get\_timeout\_reason()** interface, which is used to return the timeout reason, but the reason is not very detailed. It includes the following return values: * TOR\_NOT\_TIMEOUT: not a timeout. * TOR\_WAIT\_TIMEOUT: timed out for synchronous waiting * TOR\_CONNECT\_TIMEOUT: connection timed out. The connections on TCP, SCTP, SSL and other protocols all use this timeout. * TOR\_TRANSMIT\_TIMEOUT: timed out for all transmissions. It is impossible to further distinguish whether it is in the sending stage or in the receiving stage. It may be refined later. * For a server task, if the timeout reason is TRANSMIT\_TIMEOUT, it must be in the stage of sending replies. ### Implementation of timeout functions Within the framework, there are more types of timeouts than those we show here. Except for wait\_timeout, all of them depend on the timer\_fd on Linux or kqueue timer on BSD system, one for each poller thread. By default, the number of poller threads is 4, which can meet the requirements of most applications. The current timeout algorithm uses the data structure of linked list and red-black tree. Its time complexity is between O(1) and O(logn), where n is the fd number of the a poller thread. Currently timeout processing is not the bottleneck, because the time complexity of related calls of epoll in Linux kernel is also O(logn). If the time complexity of all timeouts in our framework reaches O(1), there is no much difference. workflow-0.11.8/docs/en/about-timer.md000066400000000000000000000104511476003635400175700ustar00rootroot00000000000000# About timer Timers are used to specify a certain waiting time without occupying a thread. The expiration of a timer is notified also by a callback. # Creating a timer Timer interfaces in WFTaskFactory: ~~~cpp using timer_callback_t = std::function; class WFTaskFactory { ... 
public: static WFTimerTask *create_timer_task(time_t seconds, long nanoseconds, timer_callback_t callback); static WFTimerTask *create_timer_task(const std::string& timer_name, time_t seconds, long nanoseconds, timer_callback_t callback); static int cancel_by_name(const std::string& timer_name) { return cancel_by_name(timer_name, (size_t)-1); } static int cancel_by_name(const std::string& timer_name, size_t max); }; ~~~ We specify the timing time of a timer through the seconds and nanoseconds parameters. Among them, the value range of nanoseconds is [0,1000000000). When creating a timer, a timer_name can be specified. And we may interrupt a timer by calling **cancel_by_name** with this name later. As a standard workflow task, there is also a user\_data field in the timer task that can be used to transfer some user data. Its starting method is the same as other tasks, and the procedure for adding it into the workflow is also the same. # Canceling a timer A named timer can be interrupted through the WFTaskFactory::cancel_by_name interface, which will cancel all timers under the name by default. So we provide another cancel interface with the second argument **max** for users to cancel at most **max** timers. Each interface returns the number of timers that were actually canceled. And of course, if there is no timer under the name, nothing is performed and 0 is returned. You can cancel a timer right after it's created, for example: ~~~cpp #include <stdio.h> #include "workflow/WFTaskFactory.h" int main() { WFTimerTask *timer = WFTaskFactory::create_timer_task("test", 10000, 0, [](WFTimerTask *task) { printf("timer callback, state = %d, error = %d.\n", task->get_state(), task->get_error()); }); WFTaskFactory::cancel_by_name("test"); timer->start(); getchar(); return 0; } ~~~ This program prints 'timer callback, state = 1, error = 125.' immediately, because the timer has been canceled before being started, and it will run to callback soon after it's started. 
And the state code would be WFT_STATE_SYS_ERROR and the error code would be ECANCELED. By the way, create named timer when and only when you may need to cancel it, because it costs more. In other scenarios just use anonymous timer. # Interrupting timer by program exit In [About exit](/docs/en/about-exit.md), you learn that the condition that a main thread can safely end (calls **exit()** or return in the main function) is that all tasks have been run to the callback and no new task is started. Then, there may be a problem, if you wait for the timer to expire, it will take a long time for the program to exit. But in practice, exiting the program can interrupt the timer safely and make it return to the callback. If the timer is interrupted by exiting the program, **get\_state()** will return a WFT\_STATE\_ABORTED state. Of course, if the timer is interrupted by exiting the program, no new tasks can be started. The following program demonstrates crawling one HTTP page at every one second. When all URLs are crawled, the program exits directly without waiting for the timer to return to the callback, and there will be no delay in exiting. ~~~cpp bool program_terminate = false; void timer_callback(WFTimerTask *timer) { mutex.lock(); if (!program_terminate) { WFHttpTask *task; if (urls_to_fetch > 0) { task = WFTaskFactory::create_http_task(...); series_of(timer)->push_back(task); } series_of(timer)->push_back(WFTaskFactory::create_timer_task(1, 0, timer_callback)); } mutex.unlock(); } ... int main() { .... /* all urls done */ mutex.lock(); program_terminate = true; mutex.unlock(); return 0; } ~~~ In the above program, the timer\_callback must check the program\_terminate condition in the lock, otherwise a new task may be started when the program has terminated. 
workflow-0.11.8/docs/en/about-tlv-message.md000066400000000000000000000114711476003635400207020ustar00rootroot00000000000000# About TLV (Type-Length-Value) format message A TLV message is a message consisting of type, length, and value. Because its format is simple and universal, and it is convenient for nesting and expansion, it is especially suitable for defining communication messages. To make it easier for users to implement custom protocols, we have built-in support for TLV messages. # TLV message structure The general TLV structure does not specify the bytes of the Type or Length field. In our protocol, they occupy 4 bytes each (network order). In other words, our message has an 8-byte message header and a Value content of no more than 32GB. We do not specify the meaning of the Type and Value fields. # TLVMessage class Because the definition of the TLV format is simple, the interfaces of TLVMessage are very simple too. ~~~cpp namespace protocol { class TLVMessage : public ProtocolMessage { public: int get_type() const { return this->type; } void set_type(int type) { this->type = type; } std::string *get_value() { return &this->value; } void set_value(std::string value) { this->value = std::move(value); } protected: int type; std::string value; ... }; using TLVRequest = TLVMessage; using TLVResponse = TLVMessage; } ~~~ If users directly use TLV messages for data transmission, they only need to use the above interfaces. Set and get Type and Value respectively. Value is directly returned as ``std::string``, which is convenient for users to move data directly through ``std::move`` when necessary. # An echo server/client example based on TLV message The following code directly starts a server based on TLV messages, and generates a client task through the command line for interaction. 
~~~cpp #include #include #include #include "workflow/WFGlobal.h" #include "workflow/WFFacilities.h" #include "workflow/TLVMessage.h" #include "workflow/WFTaskFactory.h" #include "workflow/WFServer.h" using namespace protocol; using WFTLVServer = WFServer; using WFTLVTask = WFNetworkTask; using tlv_callback_t = std::function; WFTLVTask *create_tlv_task(const char *host, unsigned short port, tlv_callback_t callback) { auto *task = WFNetworkTaskFactory::create_client_task( TT_TCP, host, port, 0, std::move(callback)); task->set_keep_alive(60 * 1000); return task; } int main() { WFTLVServer server([](WFTLVTask *task) { *task->get_resp() = std::move(*task->get_req()); }); if (server.start(8888) != 0) { perror("server.start"); exit(1); } auto&& create = [](WFRepeaterTask *)->SubTask * { std::string string; printf("Input string (Ctrl-D to exit): "); std::cin >> string; if (string.empty()) return NULL; auto *task = create_tlv_task("127.0.0.1", 8888, [](WFTLVTask *task) { if (task->get_state() == WFT_STATE_SUCCESS) printf("Server Response: %s\n", task->get_resp()->get_value()->c_str()); else { const char *str = WFGlobal::get_error_string(task->get_state(), task->get_error()); fprintf(stderr, "Error: %s\n", str); } }); task->get_req()->set_value(std::move(string)); return task; }; WFFacilities::WaitGroup wait_group(1); WFRepeaterTask *repeater = WFTaskFactory::create_repeater_task(std::move(create), nullptr); Workflow::start_series_work(repeater, [&wait_group](const SeriesWork *) { wait_group.done(); }); wait_group.wait(); server.stop(); return 0; } ~~~ # To extend TLVMessage In the echo server example above, we directly use the original TLVMessage. However, it is suggested that in specific applications, users can derive TLVMessage. In the derived class, provide a richer interface to set and extract message content, avoid direct manipulation of the original Value field, and form its own secondary protocol. 
For example, if we implement a JSON protocol, we can: ~~~cpp #include "workflow/json-parser.h" // built-in JSON parser class JsonMessage : public TLVMessage { public: void set_json_value(const json_value_t *val) { this->type = JSON_TYPE; this->json_to_string(val, &this->value); // you have to implement this function } json_value_t *get_json_value() const { if (this->type == JSON_TYPE) return json_parser_parse(this->value.c_str()); // json-parser's interface else return NULL; } }; using JsonRequest = JsonMessage; using JsonResponse = JsonMessage; using JsonServer = WFServer; ~~~ This example is just to illustrate the importance of derivation. In actual applications, derived classes may be far more complicated than this. workflow-0.11.8/docs/en/about-upstream.md000066400000000000000000000561561476003635400203240ustar00rootroot00000000000000# About Upstream In nginx, Upstream represents the load balancing configuration of the reverse proxy. Here, we expand the meaning of Upstream so that it has the following characteristics: 1. Each Upstream is an independent reverse proxy 2. Accessing an Upstream is equivalent to using an appropriate strategy to select one in a group of services/targets/upstream and downstream for access 3. Upstream has load balancing, error handling, circuit breaker and other service governance capabilities 4. For multiple retries of the same request, Upstream can avoid addresses that already tried 5. Different connection parameters can be configured for different addresses through Upstream 6. Dynamically adding/removing target will take effect in real time, which is convenient for any service discovery system ### Advantages of Upstream over domain name DNS resolution Both Upstream and domain name DNS resolution can configure a group of ip to a host, but 1. DNS domain name resolution doesn’t address port number. The service DNS domain names with the same IP and different ports cannot be configured together; but it is possible for Upstream. 2. 
The set of addresses corresponding to DNS domain name resolution must be ip; while the set of addresses corresponding to Upstream can be ip, domain name or unix-domain-socket 3. Normally, DNS domain name resolution will be cached by operating system or DNS server on the network, and the update time is limited by ttl; while Upstream can be updated in real time and take effect in real time 4. The cost of DNS domain name resolution is much greater than that of Upstream resolution and selection ### Upstream of Workflow This is a local reverse proxy module, and the proxy configuration is effective for both server and client. It supports dynamic configuration and is available for any service discovery system. Currently, [workflow-k8s](https://github.com/sogou/workflow-k8s) can be used to acquire Pods information from the API server of Kubernetes. Upstream name does not include port, but upstream request supports specified port. (However, for non-built-in protocols, Upstream name temporarily needs to be added with the port to ensure parsing during construction). Each Upstream is configured with its own independent name UpstreamName, and a set of Addresses is added and set. These Addresses can be: 1. ip4 2. ip6 3. Domain name 4. unix-domain-socket ### Why replace nginx's Upstream #### Upstream working mode of nginx 1. Supports http/https protocol only 2. Requires setting up an nginx service, and the started process occupies sockets and other resources 3. The request is sent to nginx first, and nginx forwards the request to remote end, which will increase one more network communication overhead #### Local Upstream working method of workflow 1. Protocol irrelevant, you can even access mysql, redis, mongodb, etc. through upstream 2. You can directly simulate the function of reverse proxy in the process, no need to start other processes or ports 3.
The selection process is basic calculation and table lookup, no additional network communication overhead # Use Upstream ### Common interfaces ~~~cpp class UpstreamManager { public: static int upstream_create_consistent_hash(const std::string& name, upstream_route_t consitent_hash); static int upstream_create_weighted_random(const std::string& name, bool try_another); static int upstream_create_manual(const std::string& name, upstream_route_t select, bool try_another, upstream_route_t consitent_hash); static int upstream_delete(const std::string& name); public: static int upstream_add_server(const std::string& name, const std::string& address); static int upstream_add_server(const std::string& name, const std::string& address, const struct AddressParams *address_params); static int upstream_remove_server(const std::string& name, const std::string& address); ... } ~~~ ### Example 1 Random access in multiple targets Configure a local reverse proxy to evenly send all the local requests for **my_proxy.name** to 6 target servers ~~~cpp UpstreamManager::upstream_create_weighted_random( "my_proxy.name", true); // In case of fusing, retry till the available is found or all fuses are blown UpstreamManager::upstream_add_server("my_proxy.name", "192.168.2.100:8081"); UpstreamManager::upstream_add_server("my_proxy.name", "192.168.2.100:8082"); UpstreamManager::upstream_add_server("my_proxy.name", "192.168.10.10"); UpstreamManager::upstream_add_server("my_proxy.name", "test.sogou.com:8080"); UpstreamManager::upstream_add_server("my_proxy.name", "abc.sogou.com"); UpstreamManager::upstream_add_server("my_proxy.name", "abc.sogou.com"); UpstreamManager::upstream_add_server("my_proxy.name", "/dev/unix_domain_scoket_sample"); auto *http_task = WFTaskFactory::create_http_task("http://my_proxy.name/somepath?a=10", 0, 0, nullptr); http_task->start(); ~~~ Basic principles 1. Select a target randomly 2. 
If try_another is configured as true, one of all surviving targets will be selected randomly 3. Select in the main servers only, the mains and backups of the group where the selected target is located and the backup without group are regarded as valid optional objects ### Example 2 Random access among multiple targets based on weights Configure a local reverse proxy, send all **weighted.random** requests to the 3 target servers based on the weight distribution of 5/20/1 ~~~cpp UpstreamManager::upstream_create_weighted_random( "weighted.random", false); // If you don’t retry in case of fusing, the request will surely fail AddressParams address_params = ADDRESS_PARAMS_DEFAULT; address_params.weight = 5; //weight is 5 UpstreamManager::upstream_add_server("weighted.random", "192.168.2.100:8081", &address_params); // weight is 5 address_params.weight = 20; // weight is 20 UpstreamManager::upstream_add_server("weighted.random", "192.168.2.100:8082", &address_params); // weight is 20 UpstreamManager::upstream_add_server("weighted.random", "abc.sogou.com"); // weight is 1 auto *http_task = WFTaskFactory::create_http_task("http://weighted.random:9090", 0, 0, nullptr); http_task->start(); ~~~ Basic principles 1. According to the weight distribution, randomly select a target, the greater the weight is, the greater the probability is 2. If try_another is configured as true, one of all surviving targets will be selected randomly as per weights. 3. 
Select in the main servers only, the main and backup of the group where the selected target is located and the backup without group are regarded as valid optional objects ### Example 3 Access among multiple targets based on the framework's default consistent hash ~~~cpp UpstreamManager::upstream_create_consistent_hash( "abc.local", nullptr); // nullptr represents using the default consistent hash function of the framework UpstreamManager::upstream_add_server("abc.local", "192.168.2.100:8081"); UpstreamManager::upstream_add_server("abc.local", "192.168.2.100:8082"); UpstreamManager::upstream_add_server("abc.local", "192.168.10.10"); UpstreamManager::upstream_add_server("abc.local", "test.sogou.com:8080"); UpstreamManager::upstream_add_server("abc.local", "abc.sogou.com"); auto *http_task = WFTaskFactory::create_http_task("http://abc.local/service/method", 0, 0, nullptr); http_task->start(); ~~~ Basic principles 1. Each main server is regarded as 16 virtual nodes 2. The framework will use std::hash to calculate "the address + virtual index of all nodes + the number of times for this address to add into this Upstream" as the node value of the consistent hash 3. The framework will use std::hash to calculate path + query + fragment as a consistent hash data value 4. Choose the value nearest to the surviving node as the target each time 5. For each main, as long as there is a main in surviving group, or there is a backup in surviving group, or there is a surviving no group backup, it is regarded as surviving 6. If weight on AddressParams is set with upstream_add_server(), each main server is regarded as 16 * weight virtual nodes. 
This is suitable for weighted consistent hash or shrinking the standard deviation of consistent hash ### Example 4 User-defined consistent hash function ~~~cpp UpstreamManager::upstream_create_consistent_hash( "abc.local", [](const char *path, const char *query, const char *fragment) -> unsigned int { unsigned int hash = 0; while (*path) hash = (hash * 131) + (*path++); while (*query) hash = (hash * 131) + (*query++); while (*fragment) hash = (hash * 131) + (*fragment++); return hash; }); UpstreamManager::upstream_add_server("abc.local", "192.168.2.100:8081"); UpstreamManager::upstream_add_server("abc.local", "192.168.2.100:8082"); UpstreamManager::upstream_add_server("abc.local", "192.168.10.10"); UpstreamManager::upstream_add_server("abc.local", "test.sogou.com:8080"); UpstreamManager::upstream_add_server("abc.local", "abc.sogou.com"); auto *http_task = WFTaskFactory::create_http_task("http://abc.local/sompath?a=1#flag100", 0, 0, nullptr); http_task->start(); ~~~ Basic principles 1. The framework will use a user-defined consistent hash function as the data value 2. 
The rest is the same as the above principles ### Example 5 User-defined selection strategy ~~~cpp UpstreamManager::upstream_create_manual( "xyz.cdn", [](const char *path, const char *query, const char *fragment) -> unsigned int { return atoi(fragment); }, true, // If a blown target is selected, a second selection will be made nullptr); // nullptr represents using the default consistent hash function of the framework in the second selection UpstreamManager::upstream_add_server("xyz.cdn", "192.168.2.100:8081"); UpstreamManager::upstream_add_server("xyz.cdn", "192.168.2.100:8082"); UpstreamManager::upstream_add_server("xyz.cdn", "192.168.10.10"); UpstreamManager::upstream_add_server("xyz.cdn", "test.sogou.com:8080"); UpstreamManager::upstream_add_server("xyz.cdn", "abc.sogou.com"); auto *http_task = WFTaskFactory::create_http_task("http://xyz.cdn/sompath?key=somename#3", 0, 0, nullptr); http_task->start(); ~~~ Basic principles 1. The framework first determines the selection in the main server list according to the normal selection function provided by the user and then get the modulo 2. For each main server, as long as there is a main server in surviving group, or there is a backup in surviving group, or there is a surviving no group backup, it is regarded as surviving 3. If the selected target no longer survives and try_another is set as true, a second selection will be made using consistent hash function 4. 
If the second selection is triggered, the consistent hash will ensure that a surviving target will be selected, unless all machines are blown ### Example 6 Simple main-backup mode ~~~cpp UpstreamManager::upstream_create_weighted_random( "simple.name", true);//One main, one backup, nothing is different in this item AddressParams address_params = ADDRESS_PARAMS_DEFAULT; address_params.server_type = 0; /* 0 for main server */ UpstreamManager::upstream_add_server("simple.name", "main01.test.ted.bj.sogou", &address_params); // main address_params.server_type = 1; /* 1 for backup server */ UpstreamManager::upstream_add_server("simple.name", "backup01.test.ted.gd.sogou", &address_params); //backup auto *http_task = WFTaskFactory::create_http_task("http://simple.name/request", 0, 0, nullptr); auto *redis_task = WFTaskFactory::create_redis_task("redis://simple.name/2", 0, nullptr); redis_task->get_req()->set_request("MGET", {"key1", "key2", "key3", "key4"}); (*http_task * redis_task).start(); ~~~ Basic principles 1. The main-backup mode does not conflict with any of the modes shown above, and it can take effect at the same time 2. The number of main/backup is independent of each other and there is no limit. All main servers are coequal to each other, and all backup servers are coequal to each other, but main and backup are not coequal to each other. 3. As long as a main server is alive, the request will always use a main server. 4. If all main servers are blown, backup server will take over the request as a substitute target until any main server works well again 5.
In every strategy, surviving backup can be used as the basis for the survival of main ### Example 7 Main-backup + consistent hash + grouping ~~~cpp UpstreamManager::upstream_create_consistent_hash( "abc.local", nullptr);//nullptr represents using the default consistent hash function of the framework AddressParams address_params = ADDRESS_PARAMS_DEFAULT; address_params.server_type = 0; address_params.group_id = 1001; UpstreamManager::upstream_add_server("abc.local", "192.168.2.100:8081", &address_params);//main in group 1001 address_params.server_type = 1; address_params.group_id = 1001; UpstreamManager::upstream_add_server("abc.local", "192.168.2.100:8082", &address_params);//backup for group 1001 address_params.server_type = 0; address_params.group_id = 1002; UpstreamManager::upstream_add_server("abc.local", "backup01.test.ted.bj.sogou", &address_params);//main in group 1002 address_params.server_type = 1; address_params.group_id = 1002; UpstreamManager::upstream_add_server("abc.local", "backup01.test.ted.gd.sogou", &address_params);//backup for group 1002 address_params.server_type = 1; address_params.group_id = -1; UpstreamManager::upstream_add_server("abc.local", "test.sogou.com:8080", &address_params);//backup with no group mean backup for all groups and no group UpstreamManager::upstream_add_server("abc.local", "abc.sogou.com");//main, no group auto *http_task = WFTaskFactory::create_http_task("http://abc.local/service/method", 0, 0, nullptr); http_task->start(); ~~~ Basic principles 1. Group number -1 means no group, this kind of target does not belong to any group 2. The main servers without a group are coequal to each other, and they can even be regarded as one group. But they are isolated from the other main servers with a group 3. A backup without a group can serve as a backup for any group target of Global/any target without a group 4. The group number can identify which main and backup are working together 5. 
The backups of different groups are isolated from each other, and they serve the main servers of their own group only 6. A target added without explicit parameters gets the default group number -1, and its type is main ### Example 8 VNSWRR selection weighting strategy ~~~cpp UpstreamManager::upstream_create_vnswrr("nvswrr.random"); AddressParams address_params = ADDRESS_PARAMS_DEFAULT; address_params.weight = 3;//weight is 3 UpstreamManager::upstream_add_server("nvswrr.random", "192.168.2.100:8081", &address_params);//weight is 3 address_params.weight = 2;//weight is 2 UpstreamManager::upstream_add_server("nvswrr.random", "192.168.2.100:8082", &address_params);//weight is 2 UpstreamManager::upstream_add_server("nvswrr.random", "abc.sogou.com");//weight is 1 auto *http_task = WFTaskFactory::create_http_task("http://nvswrr.random:9090", 0, 0, nullptr); http_task->start(); ~~~ 1. The virtual node initialization sequence is selected according to the [SWRR algorithm](https://github.com/nginx/nginx/commit/52327e0627f49dbda1e8db695e63a4b0af4448b1) 2. The virtual nodes are initialized in batches during operation to avoid intensive computing concentration. After each batch of virtual nodes is used up, the next batch of virtual node lists can be initialized. 3. It has both the smooth and scattered characteristics of [SWRR algorithm](https://github.com/nginx/nginx/commit/52327e0627f49dbda1e8db695e63a4b0af4448b1) and the time complexity of O(1) 4. For specific details of the algorithm, see [tengine](https://github.com/alibaba/tengine/pull/1306) # Upstream selection strategy When the URIHost of the url that initiates the request is filled with UpstreamName, it is regarded as a request to the Upstream corresponding to the name, and then it will be selected from the set of Addresses recorded by the Upstream: 1. Weight random strategy: select randomly according to weight 2.
Consistent hash strategy: The framework uses a standard consistent hashing algorithm, and users can define the consistent hash function consistent_hash for the requested uri 3. Manual strategy: make definite selection according to the select function that user provided for the requested uri, if the blown target is selected: **a.** If try_another is false, this request will return to failure **b.** If try_another is true, the framework uses standard consistent hash algorithm to make a second selection, and the user can define the consistent hash function consistent_hash for the requested uri 4. Main-backup strategy: According to the priority of main first, backup next, select a main server as long as it can be used. This strategy can take effect concurrently with any of [1], [2], and [3], and they influence each other. Round-robin/weighted-round-robin: regarded as equivalent to [1], not available for now The framework recommends common users to use strategy [2], which can ensure that the cluster has good fault tolerance and scalability For complex scenarios, advanced users can use strategy [3] to customize complex selection logic # Address attribute ~~~cpp struct EndpointParams { size_t max_connections; int connect_timeout; int response_timeout; int ssl_connect_timeout; bool use_tls_sni; }; static constexpr struct EndpointParams ENDPOINT_PARAMS_DEFAULT = { .max_connections = 200, .connect_timeout = 10 * 1000, .response_timeout = 10 * 1000, .ssl_connect_timeout = 10 * 1000, .use_tls_sni = false, }; struct AddressParams { struct EndpointParams endpoint_params; unsigned int dns_ttl_default; unsigned int dns_ttl_min; unsigned int max_fails; unsigned short weight; int server_type; /* 0 for main and 1 for backup. 
*/ int group_id; }; static constexpr struct AddressParams ADDRESS_PARAMS_DEFAULT = { .endpoint_params = ENDPOINT_PARAMS_DEFAULT, .dns_ttl_default = 12 * 3600, .dns_ttl_min = 180, .max_fails = 200, .weight = 1, // only for main of UPSTREAM_WEIGHTED_RANDOM .server_type = 0, .group_id = -1, }; ~~~ Each address can be configured with custom parameters: * max_connections, connect_timeout, response_timeout, ssl_connect_timeout of EndpointParams: connection-related parameters * dns_ttl_default: The default ttl in the dns cache in seconds, and the default value is 12 hours. The dns cache is for the current process, that is, the cache disappears when the process exits, and the configuration is only valid for the current process * dns_ttl_min: The shortest effective time of dns in seconds, and the default value is 3 minutes. It is used to decide whether to perform dns again when communication fails and is retried. * max_fails: the number of [continuous] failures that triggers fusing (Note: each time the communication is successful, the count will be cleared) * weight: weight, the default value is 1, which is only valid for main. It is used for Upstream weighted random strategy selection and consistent hash strategy selection, the larger the weight is, the easier it is to be selected. * server_type: main/backup configuration, main by default (server_type=0). At any time, the main servers in the same group are always at higher priority than backups * group_id: basis for grouping, the default value is -1. -1 means no grouping (free). A free backup can be regarded as backup to any main server. Any backup with group is always at higher priority than any free backup. # About fuse ## MTTR Mean time to repair (MTTR) is the average value of the repair time when the product changes from a fault state to a working state.
## Service avalanche effect Service avalanche effect is a phenomenon in which "service caller failure" (result) is caused by "service provider's failure" (cause), and the unavailability is amplified gradually/level by level If it is not controlled effectively, the effect will not converge, but will be amplified geometrically, just like an avalanche, that’s why it is called avalanche effect Description of the phenomenon: at first it is just a small service or module abnormality/timeout, causing abnormality/timeout of other downstream dependent services, then causing a chain reaction, eventually leading to paralysis of most or all services As the fault is repaired, the effect will disappear, so the duration of the effect is usually equal to MTTR ## Fuse mechanism When the error or abnormal touch of a certain target meets the preset threshold condition, the target is temporarily considered unavailable, and the target is removed, namely fuse is started and enters the fuse period After the fuse duration reaches MTTR duration, turn into half-open status, (attempt to) restore the target If all targets are found fused whenever recovering one target, all targets will be restored at the same time Fuse mechanism strategy can effectively prevent avalanche effect ## Upstream fuse protection mechanism MTTR=30 seconds, which is temporarily not configurable, but we will consider opening it to be configured by users in the future. When the number of consecutive failures of a certain Address reaches the set upper limit (200 times by default), this Address will be blown, MTTR=30 seconds. During the fusing period, once the Address is selected by the strategy, Upstream will decide whether to try other Addresses and how to try according to the specific configuration Please note that if one of the following 1-4 scenarios is met, the communication task will get an error of WFT_ERR_UPSTREAM_UNAVAILABLE = 1004: 1. Weight random strategy, all targets are in the fusing period 2. 
Consistent hash strategy, all targets are in the fusing period 3. Manual strategy and try_another==true, all targets are in the fusing period 4. Manual strategy and try_another==false, and all the following three conditions shall meet at the same time: 1). The main selected by the select function is in the fusing period, and all free devices are in the fusing period 2). The main is a free main, or other targets in the group where the main is located are all in the fusing period 3). All free devices are in the fusing period # Upstream port priority 1. Priority is given to the port number explicitly configured on the Upstream Address 2. If not, select the port number explicitly configured in the request url 3. If none, use the default port number of the protocol ~~~cpp Configure UpstreamManager::upstream_add_server("my_proxy.name", "192.168.2.100:8081"); Request http://my_proxy.name:456/test.html => http://192.168.2.100:8081/test.html Request http://my_proxy.name/test.html => http://192.168.2.100:8081/test.html ~~~ ~~~cpp Configure UpstreamManager::upstream_add_server("my_proxy.name", "192.168.10.10"); Request http://my_proxy.name:456/test.html => http://192.168.10.10:456/test.html Request http://my_proxy.name/test.html => http://192.168.10.10:80/test.html ~~~ workflow-0.11.8/docs/en/tutorial-01-wget.md000066400000000000000000000117071476003635400203720ustar00rootroot00000000000000# Creating your first task: wget # Sample code [tutorial-01-wget.cc](/tutorial/tutorial-01-wget.cc) # About wget wget reads HTTP/HTTPS URLs from stdin, crawls the webpages and then print the content to stdout. It also outputs the HTTP headers of the request and the response to stderr. For convenience, wget exits with Ctrl-C, but it will ensure that all resources are completely released first. 
# Creating and starting an HTTP task ~~~cpp WFHttpTask *task = WFTaskFactory::create_http_task(url, REDIRECT_MAX, RETRY_MAX, wget_callback); protocol::HttpRequest *req = task->get_req(); req->add_header_pair("Accept", "*/*"); req->add_header_pair("User-Agent", "Wget/1.14 (gnu-linux)"); req->add_header_pair("Connection", "close"); task->start(); pause(); ~~~ **WFTaskFactory::create\_http\_task()** generates an HTTP task. In [WFTaskFactory.h](/src/factory/WFTaskFactory.h), the prototype is defined as follows: ~~~cpp WFHttpTask *create_http_task(const std::string& url, int redirect_max, int retry_max, http_callback_t callback); ~~~ The first few parameters are self-explanatory. **http\_callback\_t** is the callback of an HTTP task, which is defined below: ~~~cpp using http_callback_t = std::function<void (WFHttpTask *)>; ~~~ To put it simply, it’s the function that has **Task** as one parameter and does not return any value. You can pass NULL to this callback, indicating that there is no callback. The callback in all tasks follows the same rule. Please note that no factory function ever returns failure, so even if the URL is illegal, don't worry that the task is a null pointer. All errors are handled in the callback. You can use **task->get\_req()** to get the request of the task. The default method is GET via HTTP/1.1 on long connections. The framework automatically adds request\_uri, Host and other parameters. The framework will add other HTTP header fields automatically according to the actual requirements, including Content-Length or Connection before sending the request. You may also use **add\_header\_pair()** to add your own header. For more interfaces on HTTP messages, please see [HttpMessage.h](/src/protocol/HttpMessage.h). **task->start()** starts the task. It’s non-blocking and will not fail. Then the callback of the task will be called. As it’s an asynchronous task, obviously you cannot use the task pointer after **start()**.
To make the example as simple as possible, call **pause()** after **start()** to prevent the program from exiting. You can press Ctrl-C to exit the program. # Handling crawled HTTP results This example demonstrates how to handle the results with a general function. Of course, **std::function** supports more features. ~~~cpp void wget_callback(WFHttpTask *task) { protocol::HttpRequest *req = task->get_req(); protocol::HttpResponse *resp = task->get_resp(); int state = task->get_state(); int error = task->get_error(); // handle error states ... std::string name; std::string value; // print request to stderr fprintf(stderr, "%s %s %s\r\n", req->get_method(), req->get_http_version(), req->get_request_uri()); protocol::HttpHeaderCursor req_cursor(req); while (req_cursor.next(name, value)) fprintf(stderr, "%s: %s\r\n", name.c_str(), value.c_str()); fprintf(stderr, "\r\n"); // print response header to stderr ... // print response body to stdout const void *body; size_t body_len; resp->get_parsed_body(&body, &body_len); // always success. fwrite(body, 1, body_len, stdout); fflush(stdout); } ~~~ In this callback, the task is generated by the factory. You can use **task->get\_state()** and **task->get\_error()** to obtain the running status and the error code of the task respectively. Let's skip the error handling first. Use **task->get\_resp()** to get the response of the task, which differs only slightly from the request, as both are derived from HttpMessage. Then, use the HttpHeaderCursor to scan the headers of the request and the response. [HttpUtil.h](/src/protocol/HttpUtil.h) contains the definition of the Cursor. ~~~cpp class HttpHeaderCursor { public: HttpHeaderCursor(const HttpMessage *message); ... void rewind(); ... bool next(std::string& name, std::string& value); bool find(const std::string& name, std::string& value); ... }; ~~~ There should be no doubt about the use of this cursor. The next line **resp->get\_parsed\_body()** obtains the HTTP body of the response.
This call always returns true when the task is successful, and the body points to the data area. The call gets the raw HTTP body, and does not decode the chunk. If you want to decode the chunk, you can use the HttpChunkCursor in [HttpUtil.h](/src/protocol/HttpUtil.h). In addition, **find()** will change the pointer inside the cursor. If you want to iterate over the header after you use **find()**, please use **rewind()** to return to the cursor header. workflow-0.11.8/docs/en/tutorial-02-redis_cli.md000066400000000000000000000147521476003635400213650ustar00rootroot00000000000000# Implementing Redis set and get: redis\_cli # Sample code [tutorial-02-redis\_cli.cc](/tutorial/tutorial-02-redis_cli.cc) # About redis\_cli The program reads the Redis server address and a key/value pair from the command line. It then executes SET to write this KV pair and reads it back to verify that the write was successful. Command: ./redis_cli \<redis URL\> \<key\> \<value\> For the sake of simplicity, press Ctrl-C to exit the program. # Format of Redis URL redis://:password@host:port/dbnum?query#fragment If SSL is used, use: rediss://:password@host:port/dbnum?query#fragment password is optional. The default port is 6379; the default dbnum is 0, and its range is from 0 to 15. query and fragment are not used in the factory and you can define them by yourself. For example, if you want to use upstream selection, you can define your own query and fragment. For relevant details, please see upstream documents. Sample Redis URL: redis://127.0.0.1/ redis://:12345678@redis.some-host.com/1 # Creating and starting a Redis task Creating a Redis task is almost the same as creating an HTTP task. The only difference is the omission of redirect\_max. ~~~cpp using redis_callback_t = std::function<void (WFRedisTask *)>; WFRedisTask *create_redis_task(const std::string& url, int retry_max, redis_callback_t callback); ~~~ In this example, we want to store some user data in the Redis task, including URL and key, and use them in the callback.
We can use **std::function** to bind the parameters. Here we use **void \*user\_data** pointer in the task. The pointer is a public member of the task. ~~~cpp struct tutorial_task_data { std::string url; std::string key; }; ... struct tutorial_task_data data; data.url = argv[1]; data.key = argv[2]; WFRedisTask *task = WFTaskFactory::create_redis_task(data.url, RETRY_MAX, redis_callback); protocol::RedisRequest *req = task->get_req(); req->set_request("SET", { data.key, argv[3] }); task->user_data = &data; task->start(); pause(); ~~~ Similar to **get\_req()** in an HTTP task, **get\_req()** in a Redis task returns the Redis request for that task. You can see the functions of RedisRequest in [RedisMessage.h](/src/protocol/RedisMessage.h), where **set\_request** is used to set Redis command. ~~~cpp void set_request(const std::string& command, const std::vector<std::string>& params); ~~~ There is little doubt about this interface for people who frequently use Redis. However, please note that you cannot use SELECT and AUTH commands in the request. The reason is that you cannot specify the connection each time you send a request, so the next request after SELECT may not be initiated on the same connection, which makes these commands meaningless. Please specify the database name and password in the Redis URL. And the URL of every request must contain these data. In addition, this redis client fully supports redis cluster mode. The client will process MOVED and ASK responses, and redirect correctly. # Handling results After you successfully run the SET command, send the GET command to verify the writing. GET also uses the same callback. Therefore, the function will determine the source command of the results. Again, let's skip the error handling for now. ~~~cpp void redis_callback(WFRedisTask *task) { protocol::RedisRequest *req = task->get_req(); protocol::RedisResponse *resp = task->get_resp(); int state = task->get_state(); int error = task->get_error(); protocol::RedisValue val; ...
resp->get_result(val); std::string cmd; req->get_command(cmd); if (cmd == "SET") { tutorial_task_data *data = (tutorial_task_data *)task->user_data; WFRedisTask *next = WFTaskFactory::create_redis_task(data->url, RETRY_MAX, redis_callback); next->get_req()->set_request("GET", { data->key }); series_of(task)->push_back(next); fprintf(stderr, "Redis SET request success. Trying to GET...\n"); } else /* if (cmd == 'GET') */ { // print the GET result ... fprintf(stderr, "Finished. Press Ctrl-C to exit.\n"); } } ~~~ RedisValue is the result of one Redis request. You can also see the interface in [RedisMessage.h](/src/protocol/RedisMessage.h). You need to pay special attention to the callback in the line **series\_of(task)->push\_back(next)**. It's the first time we use the functions of Workflow. Here **next** means the Redis task we are about to start: run GET operation. We do not use **next->start()** to start the task. We use **push\_back** to append the next task to the end of the current task queue instead. The difference between the two methods is: * When a task is initiated by **start**, the task is started immediately; when a task is **push\_back** to the queue, the **next** task is initiated after the callback. * The obvious advantage is that the **push\_back** method can ensure that the log printing is not chaotic. Otherwise, if you use the **next->start()**, the \"Finished.\" in the sample may be printed out first. * If you use **start** to initiate the next task, the current task series ends and the next task will initiate a new series. * You can set a callback for a series. For the sake of simplicity, the sample omits it. * In the parallel tasks, a series is a branch of the parallel task. If the series ends, it is considered that the branch also ends. The following tutorials demonstrate how to use parallel tasks. In a word, if you want to start the next task after one task, you usually use **push\_back** operation (in some cases, **push\_front** may be used).
**series\_of()** is a very important call and it is a global function that does not belong to any class. [Workflow.h](/src/factory/Workflow.h#L140) contains its definition and implementation. ~~~cpp static inline SeriesWork *series_of(const SubTask *task) { return (SeriesWork *)task->get_pointer(); } ~~~ All tasks are derived from SubTask. And any running task must belong to one series. You can call **series\_of()** to get the series of a task. **push\_back** is a function in the SeriesWork class, which is used to append a task to the end of the series. **push\_front** is a similar function. In the sample, you can use either function. ~~~cpp class SeriesWork { ... public: void push_back(SubTask *task); void push_front(SubTask *task); ... } ~~~ SeriesWork class plays an important role in our system. In the next tutorial, you will learn more functions in SeriesWork. workflow-0.11.8/docs/en/tutorial-03-wget_to_redis.md000066400000000000000000000073621476003635400222660ustar00rootroot00000000000000# More features about series: wget\_to\_redis # Sample code [tutorial-03-wget\_to\_redis.cc](/tutorial/tutorial-03-wget_to_redis.cc) # About wget\_to\_redis The program reads one HTTP URL and one redis URL from the command line, crawls the HTTP web page and saves the content to Redis, with the key as the HTTP URL. Differing from the other two examples, we add a wake-up mechanism. The program can automatically exit and users are not required to press Ctrl-C. # Creating and configuring an HTTP task Similar to the previous example, in this example, we also executes two requests in series. The biggest difference is that we inform the main thread that the execution of the task has finished and quit normally. In addition, we add two more calls to limit the size of the crawled HTTP response content and the maximum time to receive the reply. ~~~cpp WFHttpTask *http_task = WFTaskFactory::create_http_task(...); ... 
http_task->get_resp()->set_size_limit(20 * 1024 * 1024); http_task->set_receive_timeout(30 * 1000); ~~~ **set\_size\_limit()** is a function in HttpMessage. It is used to limit the packet size of an incoming HTTP message. Actually this interface is required in all protocol messages. **set\_receive\_timeout()** sets the timeout for receiving data, in milliseconds. The above code limits the size of the HTTP message to no more than 20M and the time for receiving the complete message to no more than 30 seconds. You can learn more about timeout configuration in the following documents. # Creating and starting a SeriesWork In the previous two examples, we call **task->start()** directly to start the first task. The actual procedure in **task->start()** is: create a SeriesWork with the task as the head and then start the series. In [WFTask.h](/src/factory/WFTask.h), you can see the implementation of **start**. ~~~cpp template<class REQ, class RESP> class WFNetworkTask : public CommRequest { public: void start() { assert(!series_of(this)); Workflow::start_series_work(this, nullptr); } ... }; ~~~ We want to set a callback for that series and add some context. Therefore, instead of using the **start** interface of the task, we create our own series. You cannot new, delete or inherit a SeriesWork. It can only be generated through **Workflow::create\_series\_work()** interface. In [Workflow.h](/src/factory/Workflow.h), generally we use the following call: ~~~cpp using series_callback_t = std::function<void (const SeriesWork *)>; class Workflow { public: static SeriesWork *create_series_work(SubTask *first, series_callback_t callback); }; ~~~ In the sample code, our usage is as follows: ~~~cpp struct tutorial_series_context { std::string http_url; std::string redis_url; size_t body_len; bool success; }; ... struct tutorial_series_context context; ...
SeriesWork *series = Workflow::create_series_work(http_task, series_callback); series->set_context(&context); series->start(); ~~~ In the previous example, we use the pointer **void \*user\_data** in the task to save the context. However, in this example, we put the context in the series, which is more reasonable. The series is a complete task chain, and all tasks can obtain and modify the context. The callback function of the series is called after all the tasks in that series are finished. Here, we simply use a lamda function to print the running results and wake up the main thread. # Other work There's nothing special left. After the HTTP crawling is successful, a Redis task is started to write the data into the database. If the crawling fails or the length of the HTTP body is 0, the Redis task will not be started. In any case, the program can exit normally after all tasks are finished, because all tasks are in the same series. workflow-0.11.8/docs/en/tutorial-04-http_echo_server.md000066400000000000000000000174141476003635400227730ustar00rootroot00000000000000# First server: http\_echo\_server # Sample code [tutorial-04-http\_echo\_server.cc](/tutorial/tutorial-04-http_echo_server.cc) # About http\_echo\_server It is an HTTP server that returns an HTML page, which displays the header data in the HTTP request sent by the browser. The log of the program contains the client address and the sequence of the request (the number of requests on the current connection). When 10 requests are completed on the same connection, the server actively closes the connection. The program exits normally after users press Ctrl-C, and all resources are completely reclaimed. # Creating and starting an HTTP server In this example, we use the default parameters of an HTTP server. It is very simple to create and start an HTTP server. ~~~cpp WFHttpServer server(process); port = atoi(argv[1]); if (server.start(port) == 0) { pause(); server.stop(); } ... 
~~~ The procedure is too simple to explain. Please note that the start process is non-blocking, so please pause the program. Obviously you can start several server objects and then pause. After a server is started, you can use **stop()** interface to shut down the server at any time. Stopping a server is graceful: it does not finish until all the requests being processed in the server are completed. Therefore, **stop** is a blocking operation. If non-blocking shutdown is required, please use **shutdown+wait\_finish** interface. There are several overloaded functions with **start()**. [WFServer.h](/src/server/WFServer.h) contains the following interfaces: ~~~cpp class WFServerBase { public: /* To start TCP server. */ int start(unsigned short port); int start(int family, unsigned short port); int start(const char *host, unsigned short port); int start(int family, const char *host, unsigned short port); int start(const struct sockaddr *bind_addr, socklen_t addrlen); /* To start an SSL server */ int start(unsigned short port, const char *cert_file, const char *key_file); int start(int family, unsigned short port, const char *cert_file, const char *key_file); int start(const char *host, unsigned short port, const char *cert_file, const char *key_file); int start(int family, const char *host, unsigned short port, const char *cert_file, const char *key_file); int start(const struct sockaddr *bind_addr, socklen_t addrlen, const char *cert_file, const char *key_file); /* For graceful restart or multi-process server. */ int serve(int listen_fd); int serve(int listen_fd, const char *cert_file, const char *key_file); /* Get the listening address. Used when started a server on a random port. */ int get_listen_addr(struct sockaddr *addr, socklen_t *addrlen) const; }; ~~~ These interfaces are easy to understand.
If the **port** number is zero, the server will be started on a random port, and you may need to call **get_listen_addr** to obtain the actual listening address (mainly for the actual port) after the server is started. When you start an SSL server, the cert\_file and key\_file should be in PEM format. The last two **serve()** interfaces have the parameter **listen\_fd**, which is used for graceful restart or for building a simple non-TCP (such as SCTP) server. Please note that one server object corresponds to one **listen\_fd**. If the server is running on both IPv4 and IPv6 protocols, you should: ~~~cpp { WFHttpServer server_v4(process); WFHttpServer server_v6(process); server_v4.start(AF_INET, port); server_v6.start(AF_INET6, port); ... // now stop... server_v4.shutdown(); /* shutdown() is nonblocking */ server_v6.shutdown(); server_v4.wait_finish(); server_v6.wait_finish(); } ~~~ In the above code, the two servers cannot share the connection counter. Therefore, it is recommended to start the IPv6 server only, because the IPv6 server can accept IPv4 connections. # Business logic of an HTTP echo server When you build an HTTP server, you pass a process parameter, which is also a **std::function**, as defined below: ~~~cpp using http_process_t = std::function<void (WFHttpTask *)>; using WFHttpServer = WFServer<protocol::HttpRequest, protocol::HttpResponse>; template<> WFHttpServer::WFServer(http_process_t proc) : WFServerBase(&HTTP_SERVER_PARAMS_DEFAULT), process(std::move(proc)) { } ~~~ Actually, the type of **http\_process\_t** and the type of **http\_callback\_t** are exactly the same. Both are used to handle WFHttpTask. The job of the server is to populate the response based on the request. Similarly, we use an ordinary function to implement the process. The process iterates over the HTTP header of the request line by line and then writes them into an HTML page.
~~~cpp void process(WFHttpTask *server_task) { protocol::HttpRequest *req = server_task->get_req(); protocol::HttpResponse *resp = server_task->get_resp(); long seq = server_task->get_task_seq(); protocol::HttpHeaderCursor cursor(req); std::string name; std::string value; char buf[8192]; int len; /* Set response message body. */ resp->append_output_body_nocopy("<html>", 6); len = snprintf(buf, 8192, "<p>%s %s %s</p>", req->get_method(), req->get_request_uri(), req->get_http_version()); resp->append_output_body(buf, len); while (cursor.next(name, value)) { len = snprintf(buf, 8192, "<p>%s: %s</p>", name.c_str(), value.c_str()); resp->append_output_body(buf, len); } resp->append_output_body_nocopy("</html>", 7); /* Set status line if you like. */ resp->set_http_version("HTTP/1.1"); resp->set_status_code("200"); resp->set_reason_phrase("OK"); resp->add_header_pair("Content-Type", "text/html"); resp->add_header_pair("Server", "Sogou WFHttpServer"); if (seq == 9) /* no more than 10 requests on the same connection. */ resp->add_header_pair("Connection", "close"); // print log ... } ~~~ You have learned most of the HttpMessage related operations. The only new operation here is **append\_output\_body()**. Obviously, it is not very efficient for the users to generate a complete HTTP body and pass it to the framework. The user only needs to call the **append** interface to append the discrete data to the message block by block. **append\_output\_body()** operation will copy the data, and another interface with the suffix **\_nocopy** will directly use the reference to the pointer. Please do not make it point to the local variables when you use it. [HttpMessage.h](../src/protocol/HttpMessage.h) contains the declaration of relevant calls. ~~~cpp class HttpMessage { public: bool append_output_body(const void *buf, size_t size); bool append_output_body_nocopy(const void *buf, size_t size); ... bool append_output_body(const std::string& buf); }; ~~~ Once again, please note that when you use **append\_output\_body\_nocopy()**, the lifecycle of the data referenced by the buf must at least be extended to the callback of the task. Another variable seq in the function is obtained by **server\_task->get\_task\_seq()**, which indicates the number of requests on the current connection, starting from 0. In the program, the connection is forcibly closed after 10 requests are completed, thus: ~~~cpp if (seq == 9) /* no more than 10 requests on the same connection. */ resp->add_header_pair("Connection", "close"); ~~~ You can also use **task->set\_keep\_alive()** to close the connection.
However, for the connection using HTTP protocol, it is recommended to set the “close” option in HTTP header. In this example, because the response page is very small, we didn't pay attention to the reply status. In the next tutorial **http\_proxy**, you will learn how to get the reply status. workflow-0.11.8/docs/en/tutorial-05-http_proxy.md000066400000000000000000000250471476003635400216520ustar00rootroot00000000000000# Asynchronous server: http\_proxy # Sample code [tutorial-05-http\_proxy.cc](/tutorial/tutorial-05-http_proxy.cc) # About http\_proxy It is an HTTP proxy server. You can use it in a browser after proper configuration. It supports all HTTP methods. As HTTPS proxy follows different principles, this example does not support HTTPS proxy. You can only browse HTTP websites. In the implementation, this proxy must crawl the entire HTTP page and then forward it. Therefore, there will be noticeable latency when you upload/download a large file. # Changing server configuration In the previous example, we use the default parameters of an HTTP server. In this tutorial, we will make some changes and limit the size of the request so as to prevent malicious attack. ~~~cpp int main(int argc, char *argv[]) { ... struct WFServerParams params = HTTP_SERVER_PARAMS_DEFAULT; params.request_size_limit = 8 * 1024 * 1024; WFHttpServer server(&params, process); if (server.start(port) == 0) { pause(); server.stop(); } else { perror("cannot start server"); exit(1); } return 0; } ~~~ Unlike the previous example, we pass an additional parameter to the server struct. Let’s see the configuration items in the HTTP server.
In [WFHttpServer.h](/src/server/WFHttpServer.h), the default parameters for an HTTP server include: ~~~cpp static constexpr struct WFServerParams HTTP_SERVER_PARAMS_DEFAULT = { .transport_type = TT_TCP, .max_connections = 2000, .peer_response_timeout = 10 * 1000, .receive_timeout = -1, .keep_alive_timeout = 60 * 1000, .request_size_limit = (size_t)-1, .ssl_accept_timeout = 10 * 1000, }; ~~~ **transport\_type**: the transport layer protocol. Besides the default type TT_TCP, you may specify TT_UDP, or TT_SCTP on Linux platform. **max\_connections**: the maximum number of connections is 2000. When it is exceeded, the least recently used keep-alive connection will be closed. If there is no keep-alive connection, the server will refuse new connections. **peer\_response\_timeout**: set the maximum duration for reading or sending out a block of data. The default setting is 10 seconds. **receive\_timeout**: set the maximum duration for receiving a complete request; -1 means unlimited time. **keep\_alive\_timeout**: set the maximum duration for maintaining a connection. The default setting is 1 minute. **request\_size\_limit**: set the maximum size of a request packet. The default setting is unlimited packet size. **ssl\_accept\_timeout**: set the maximum duration for an SSL handshake. The default setting is 10 seconds. There is no **send\_timeout** in the parameters. **send\_timeout** sets the timeout for sending a complete response. This parameter should be determined according to the size of the response packet. # Business logic of a proxy server Essentially, this proxy server forwards a user's request intactly to the corresponding web server, and then forwards the reply from the web server intactly to the user. In the request sent by a browser to the proxy, the Request URL contains scheme, host and port, which should be removed before forwarding. 
For example, when the browser visits `http://www.sogou.com/`, the first line of the request sent by the browser to the proxy is: `GET` `http://www.sogou.com/` `HTTP/1.1` which should be rewritten as: `GET` `/` `HTTP/1.1` ~~~cpp void process(WFHttpTask *proxy_task) { auto *req = proxy_task->get_req(); SeriesWork *series = series_of(proxy_task); WFHttpTask *http_task; /* for requesting remote webserver. */ tutorial_series_context *context = new tutorial_series_context; context->url = req->get_request_uri(); context->proxy_task = proxy_task; series->set_context(context); series->set_callback([](const SeriesWork *series) { delete (tutorial_series_context *)series->get_context(); }); http_task = WFTaskFactory::create_http_task(req->get_request_uri(), 0, 0, http_callback); const void *body; size_t len; /* Copy user's request to the new task's request using std::move() */ req->set_request_uri(http_task->get_req()->get_request_uri()); req->get_parsed_body(&body, &len); req->append_output_body_nocopy(body, len); *http_task->get_req() = std::move(*req); /* also, limit the remote webserver response size. */ http_task->get_resp()->set_size_limit(200 * 1024 * 1024); *series << http_task; } ~~~ The above contains the entire content of the process. It first constructs the HTTP request to be sent to the web server. **req->get\_request\_uri()** is used to get the complete URL of the request sent by a browser. It then builds an HTTP task to the server based on this URL. Both the retry times and the redirection times of this HTTP task are 0, because the redirection is handled by the browser and the browser will resend the request when it meets 302, etc. ~~~cpp req->set_request_uri(http_task->get_req()->get_request_uri()); req->get_parsed_body(&body, &len); req->append_output_body_nocopy(body, len); *http_task->get_req() = std::move(*req); ~~~ In fact, the above four lines generate an HTTP request to the web server.
req is the received HTTP request, and it will be moved directly to the new request via **std::move()**. The first line removes the `http://host:port` in the request\_uri and keeps the part after the path. The second line and the third line specify the parsed HTTP body as the HTTP body for output. The reason for this operation is that in the HttpMessage implementation, the http body obtained by parsing and the http body to send out are two fields, so we need to simply set it here, without copying the memory. The fourth line transfers the request content to the request sent to the web server at one time. After the HTTP request is constructed, the request is placed at the end of the current series, and the process function ends. # Principles behind an asynchronous server Obviously, the process function is only part of the proxy logic. We also need to handle the HTTP response returned from the web server and generates the response for the browser. In the example of echo server, we populate the response page directly without network communication. However, in the proxy server, we have to wait for the response from the web server. Of course, we can occupy the thread of this process function and wait for the returned result, but this synchronous waiting mode is obviously not desirable. Thus, it is better that we reply to the user's request asynchronously after receiving the results for the request, and no thread is occupied while we are waiting for the result. Therefore, we set a context for the current series in the head of the process, which contains the proxy\_task itself. In this way, we can populate the results asynchronously. ~~~cpp struct tutorial_series_context { std::string url; WFHttpTask *proxy_task; bool is_keep_alive; }; void process(WFHttpTask *proxy_task) { SeriesWork *series = series_of(proxy_task); ... 
tutorial_series_context *context = new tutorial_series_context; context->url = req->get_request_uri(); context->proxy_task = proxy_task; series->set_context(context); series->set_callback([](const SeriesWork *series) { delete (tutorial_series_context *)series->get_context(); }); ... } ~~~ In the previous client example, we said that any running task is in a series, and the server task is no exception. Thus, we can get the current series and set the context. In which the URL is mainly used for the subsequent logs, and the proxy\_task is the main content, which is used for resp later. Next, Let’s see how to handle the responses from the web server. ~~~cpp void http_callback(WFHttpTask *task) { int state = task->get_state(); auto *resp = task->get_resp(); SeriesWork *series = series_of(task); tutorial_series_context *context = (tutorial_series_context *)series->get_context(); auto *proxy_resp = context->proxy_task->get_resp(); ... if (state == WFT_STATE_SUCCESS) { const void *body; size_t len; /* set a callback for getting reply status. */ context->proxy_task->set_callback(reply_callback); /* Copy the remote webserver's response, to proxy response. */ resp->get_parsed_body(&body, &len); resp->append_output_body_nocopy(body, len); *proxy_resp = std::move(*resp); ... } else { // return a "404 Not found" page ... } } ~~~ Here we focus on the successful cases only. If the proxy gets a complete HTTP page from the web server, no matter what the return code is, it is considered a success. All failure will simply return a 404 page. Because the data returned to the user may be very large, the maximum size is set to 200MB in this example. Therefore, unlike the previous examples, we need to check the success/failure status of the reply. The type of an HTTP server task is identical to the type of an HTTP client task created by ourselves. Both are WFHttpTask. The difference is that a server task is created by the framework, and its callback is initially empty. 
The callback of a server task is the same as that of a client. Both are called after an HTTP interaction is completed. Therefore, for all server tasks, the callback is called after the reply is completed. The following three lines of code are explained before. They transfer the response packets from the web server to the proxy response packets without copying. After the **http\_callback** function is ended, the reply to the browser is sent out. Everything is done asynchronously. The remaining function **reply\_callback()** is used just to print some logs here. The proxy task will be automatically deleted after this callback is finished. Finally, the context is destroyed in the callback of the series. # Timing of a server reply Please note that the reply message is sent automatically after all other tasks in the series are finished, so there is no **task->reply()** interface. However, there is a **task->noreply()**. If this interface is called for the server task, the connection will be closed directly at the original reply time. But the callback will still be called (its state is NOREPLY). In the callback of a server task, you can also call **series\_of()** to get the series of that server task. Then, you can still add new tasks to this series, although the reply has finished. workflow-0.11.8/docs/en/tutorial-06-parallel_wget.md000066400000000000000000000126211476003635400222470ustar00rootroot00000000000000# A simple parallel wget: parallel\_wget # Sample code [tutorial-06-parallel\_wget.cc](/tutorial/tutorial-06-parallel_wget.cc) # About parallel\_wget It is our first example on parallel tasks. The program reads multiple HTTP URLs (separated by spaces) from the command line, crawls these URLs in parallel, and prints the crawled results to the standard output according to the input order. # Creating a parallel task In the previous example, you have already learned the SeriesWork class. * SeriesWork consists of a series of tasks that are executed sequentially. 
The series finishes when all its tasks finish. * ParallelWork class, corresponding to the SeriesWork, consists of multiple series that are executed in parallel. The parallel work finishes when all its series finish. * ParallelWork is a task. According to the above definition, you can generate any complex workflow dynamically or statically. The Workflow class has two interfaces for generating parallel tasks: ~~~cpp class Workflow { ... public: static ParallelWork * create_parallel_work(parallel_callback_t callback); static ParallelWork * create_parallel_work(SeriesWork *const all_series[], size_t n, parallel_callback_t callback); ... }; ~~~ The first interface creates an empty parallel task, and the second interface creates parallel tasks with a series array. Before you start the parallel work, you can use **add\_series()** interface of the ParallelWork to add series to the parallel tasks generated by either interface. In the sample code, we create an empty parallel task and then add the series one by one. ~~~cpp int main(int argc, char *argv[]) { ParallelWork *pwork = Workflow::create_parallel_work(callback); SeriesWork *series; WFHttpTask *task; HttpRequest *req; tutorial_series_context *ctx; int i; for (i = 1; i < argc; i++) { std::string url(argv[i]); ... task = WFTaskFactory::create_http_task(url, REDIRECT_MAX, RETRY_MAX, [](WFHttpTask *task) { // store resp to ctx. }); req = task->get_req(); // add some headers. ... ctx = new tutorial_series_context; ctx->url = std::move(url); series = Workflow::create_series_work(task, nullptr); series->set_context(ctx); pwork->add_series(series); } ... } ~~~ You can see that we first create an HTTP task in the code, but the HTTP task cannot be directly added to the parallel task, so we need to use it to create a series first. Each series has its own context, which is used to save the URL and the crawled results. You can learn related methods in our previous examples. 
# Saving and using the crawled results The callback of an HTTP task is a simple lambda function, which saves the crawled result in its own series context, so that it can be retrieved by the parallel task. ~~~cpp task = WFTaskFactory::create_http_task(url, REDIRECT_MAX, RETRY_MAX, [](WFHttpTask *task) { tutorial_series_context *ctx = (tutorial_series_context *)series_of(task)->get_context(); ctx->state = task->get_state(); ctx->error = task->get_error(); ctx->resp = std::move(*task->get_resp()); }); ~~~ This is necessary, because HTTP tasks will be recycled after the callback, so we have to use **std::move()** to move the resp. In the callback of parallel tasks, we can easily get the results: ~~~cpp void callback(const ParallelWork *pwork) { tutorial_series_context *ctx; const void *body; size_t size; size_t i; for (i = 0; i < pwork->size(); i++) { ctx = (tutorial_series_context *)pwork->series_at(i)->get_context(); printf("%s\n", ctx->url.c_str()); if (ctx->state == WFT_STATE_SUCCESS) { ctx->resp.get_parsed_body(&body, &size); printf("%zu%s\n", size, ctx->resp.is_chunked() ? " chunked" : ""); fwrite(body, 1, size, stdout); printf("\n"); } else printf("ERROR! state = %d, error = %d\n", ctx->state, ctx->error); delete ctx; } } ~~~ Here, you can see the two new interfaces of ParallelWork, **size()** and **series\_at(i)**, which are used to obtain the number of the series in parallel and the ith parallel series respectively. You can use **series->get\_context()** to get the context of the series and print out the results. The printing order is the same as the order in which you added the series to the parallel work. In this example, there is no other work after the parallel tasks finish. As we said above, ParallelWork is a kind of task, so you can use **series\_of()** to get its series and add a new task. However, if the crawled results are used in the new task, you need to use **std::move()** to move the data to the context of the series of that parallel task.
# Starting a parallel task As a parallel task is a kind of tasks, so there is nothing special in starting a parallel task. You can call **start()** directly, or you can use it to build or start a series. In this example, we start a series, wake up the main process in the callback of this series, and exit the program normally. We can also wake up the main process in the callback of parallel tasks, and there is little difference in the program behaviors. However, it is more formal to wake up the main process in the callback of the series. workflow-0.11.8/docs/en/tutorial-07-sort_task.md000066400000000000000000000157071476003635400214470ustar00rootroot00000000000000# Using the built-in algorithm factory: sort\_task # Sample code [tutorial-07-sort\_task.cc](/tutorial/tutorial-07-sort_task.cc) # About sort\_task The program reads a number n from the command line, sorts the random n positive integers in ascending order, and then sorts the results in descending order. You can add the second parameter "p” to the program, and then it can be sorted in parallel. For example: ./sort\_task 100000000 p The above command will sort 100 million integers in ascending order and then in descending order. The two sortings are done in parallel respectively. # About computing tasks Computing tasks (or thread tasks) is a very important function in the framework. When you use the task flow, it is not recommended to directly perform very complicated computation in the callback. All the computations that consume a lot of CPU time can be encapsulated into computing tasks and handed over to the system for scheduling. There is no difference in the usage between computing tasks and networking tasks. The algorithm factory of the system provides some common computing tasks, such as sorting, merging and so on. You can also easily define your own computing tasks. # Creating sorting tasks in ascending order ~~~cpp int main(int argc, char *argv[]) { ... 
WFSortTask<int> *task; if (use_parallel_sort) task = WFAlgoTaskFactory::create_psort_task("sort", array, end, callback); else task = WFAlgoTaskFactory::create_sort_task("sort", array, end, callback); ... task->start(); ... } ~~~ Unlike WFHttpTask or WFRedisTask, the sorting task has one more template parameter to represent the type of array data to be sorted. **create\_sort\_task** and **create\_psort\_task** produce a common sorting task and a parallel sorting task respectively. Their parameters and return values are the same. The only thing that needs special explanation is the first parameter "sort", which is the name of the computation queue. It is used to instruct the internal task scheduling. The latter part in this article explains the usage of the queue name. There is no difference in the starting methods and usage between computing tasks and networking tasks. # Handling results Like a networking task, the results are handled in the callback. In this example, the ascending sorting is followed by one descending sorting. ~~~cpp using namespace algorithm; void callback(WFSortTask<int> *task) { SortInput<int> *input = task->get_input(); int *first = input->first; int *last = input->last; // print result ... if (task->user_data == NULL) { auto cmp = [](int a1, int a2){ return a2 < a1; }; WFSortTask<int> *reverse; if (use_parallel_sort) reverse = WFAlgoTaskFactory::create_psort_task("sort", first, last, cmp, callback); else reverse = WFAlgoTaskFactory::create_sort_task("sort", first, last, cmp, callback); reverse->user_data = (void *)1; /* as a flag */ series_of(task)->push_back(reverse); } else { // all done. Signal main thread to exit. ... } } ~~~ You can use **get\_input()** interface of a computing task to get the input data, and use **get\_output()** to get the output data. For sorting tasks, the input and output are of the same type, and the contents are exactly the same.
[WFAlgoTaskFactory.h](/src/factory/WFAlgoTaskFactory.h) contains the definitions of the input and output of sorting tasks. ~~~cpp namespace algorithm { template struct SortInput { T *first; T *last; }; template using SortOutput = SortInput; } template using WFSortTask = WFThreadTask, algorithm::SortOutput>; template using sort_callback_t = std::function *)>; ~~~ Obviously, the first and last in the input or output mean the head pointer and the tail pointer of the array to be sorted. Next, we will create a descending sorting task. In this case, we need to pass in a comparison function. ~~~cpp auto cmp = [](int a1, int a2)->bool{ return a2 < a1; }; reverse = WFAlgoTaskFactory::create_sort_task("sort", first, last, cmp, callback); ~~~ Our usage differs slightly from **std::sort()**. Our first and last are pointers, not iterators. Similarly, you can use **create\_psort\_task()** to create a parallel sorting task. And the use of series in the sorting task is no different from that in the networking task. # About the configuration of the computing threads If you don't make any configuration, the calculation scheduler will set the number of threads as the number of the CPU cores in the machine. You can change the value with the following method: ~~~cpp #include "workflow/WFGlobal.h" int main() { struct WFGlobalSettings settings = GLOBAL_SETTINGS_DEFAULT; settings.compute_threads = 16; WORKFLOW_library_init(&settings); ... } ~~~ With the above configuration, the system will create 16 threads for computations. # About the parallel sorting algorithm The built-in parallel sorting algorithm use block+two-way merge. Its space complexity is O(1). The algorithm uses globally configured computing threads for computation, but at most 128 threads can be used. Because no extra space is used, the speedup ratio will be smaller than the number of threads, and the average CPU usage will be smaller. 
For the detailed implementation, please see [WFAlgoTaskFactory.inl](/src/factory/WFAlgoTaskFactory.inl). # About the name of a calculation task queue The computing task does not have priority levels. The only thing that can affect the scheduling order is the queue name of a computing task. In this example, the queue name is a string "sort". To name a queue is very simple. Please note the following items: * The queue name is a static string, and new queue names cannot be generated infinitely. For example, you cannot generate the queue name according to the request id, because each queue is allocated a small block of resources internally. * If the computing threads are not 100% occupied, all tasks are started in real time, and the queue names have no effect. * If there are multiple computing steps in a service flow and they are interspersed among multiple network communications, you can simply give each calculation step a name, which is better than using one name as a whole. * If all computing tasks use the same name, the scheduling order of all tasks is consistent with the order of submission, which will affect the average response time in some scenarios. * If each kind of computing task has an independent name, it means that they are scheduled fairly. And the same kind of tasks are scheduled sequentially, the practical effect is better. * In a word, unless the computing load of the machine is already very heavy, you do not need to pay special attention to the queue name and you can just give each kind of task a name. workflow-0.11.8/docs/en/tutorial-08-matrix_multiply.md000066400000000000000000000171041476003635400226730ustar00rootroot00000000000000# User-defined computing tasks: matrix\_multiply # Sample code [tutorial-08-matrix\_multiply.cc](/tutorial/tutorial-08-matrix_multiply.cc) # About matrix\_multiply The program multiplies two matrices and prints the results on the screen. 
The main purpose of the example is to show how to implement a user-defined CPU computing task. # About computing tasks You need to provide three types of basic information when you define a computing task: INPUT, OUTPUT, and routine. INPUT and OUTPUT are two template parameters, which can be of any type. routine means the process from INPUT to OUTPUT, which is defined as follows: ~~~cpp template class __WFThreadTask { ... std::function routine; ... }; ~~~ It can be seen that routine is a simple computing process from INPUT to OUTPUT. The INPUT pointer does not have to be const, but you can also pass a function that takes const INPUT \*. For example, to implement an adding task, you can: ~~~cpp struct add_input { int x; int y; }; struct add_output { int res; }; void add_routine(const add_input *input, add_output *output) { output->res = input->x + input->y; } typedef WFThreadTask add_task; ~~~ In the example of matrix multiplication, the input is two matrices and the output is one matrix. They are defined as follows: ~~~cpp namespace algorithm { using Matrix = std::vector>; struct MMInput { Matrix a; Matrix b; }; struct MMOutput { int error; size_t m, n, k; Matrix c; }; void matrix_multiply(const MMInput *in, MMOutput *out) { ... } } ~~~ As the input matrices may be illegal for matrix multiplication, there is an error field in the output to indicate errors. # Generating computing tasks After you define the types of input and output and the algorithm process, you can use WFThreadTaskFactory to generate a computing task. In [WFTaskFactory.h](/src/factory/WFTaskFactory.h), the computing task factory is defined as follows: ~~~cpp template class WFThreadTaskFactory { private: using T = WFThreadTask; public: static T *create_thread_task(const std::string& queue_name, std::function routine, std::function callback); static T *create_thread_task(time_t seconds, long nanoseconds, const std::string& queue_name, std::function routine, std::function callback); ...
}; ~~~ There are two interfaces for creating tasks here. The second interface lets the user pass in a running time limit for the task; we will introduce this feature in the next section. Slightly different from the previous network factory class or the algorithm factory class, this factory requires two template parameters: INPUT and OUTPUT. queue\_name is explained in the previous example. routine is the computation process, and callback means the callback. In our example, we see this call: ~~~cpp using MMTask = WFThreadTask; using namespace algorithm; int main() { typedef WFThreadTaskFactory MMFactory; MMTask *task = MMFactory::create_thread_task("matrix_multiply_task", matrix_multiply, callback); MMInput *input = task->get_input(); input->a = {{1, 2, 3}, {4, 5, 6}}; input->b = {{7, 8}, {9, 10}, {11, 12}}; ... } ~~~ After the task is generated, use the **get\_input()** interface to get the pointer of the input data. This is similar to the **get\_req()** in a network task. The start and the end of a task are the same as those of a network task. Similarly, the callback is very simple: ~~~cpp void callback(MMTask *task) // MMtask = WFThreadTask { MMInput *input = task->get_input(); MMOutput *output = task->get_output(); assert(task->get_state() == WFT_STATE_SUCCESS); if (output->error) printf("Error: %d %s\n", output->error, strerror(output->error)); else { printf("Matrix A\n"); print_matrix(input->a, output->m, output->k); printf("Matrix B\n"); print_matrix(input->b, output->k, output->n); printf("Matrix A * Matrix B =>\n"); print_matrix(output->c, output->m, output->n); } } ~~~ You can ignore the possibility of failure in the ordinary computing tasks, and the end state is always SUCCESS. The callback simply prints out the input and the output. If the input data are illegal, the error will be printed out.
# Computing task with running time limit Obviously, our framework cannot interrupt a computing task because it's a user function, and the users have to make sure the function will terminate normally. However, users can create a computing task with a running time limit, and if the task doesn't finish within this time, the callback will be invoked directly: ~~~cpp template class WFThreadTaskFactory { private: using T = WFThreadTask; public: static T *create_thread_task(time_t seconds, long nanoseconds, const std::string& queue_name, std::function routine, std::function callback); ... }; ~~~ This create_thread_task function takes two more parameters, seconds and nanoseconds. If the running time of the routine reaches the seconds+nanoseconds time limit, the callback is invoked directly, with the state set to WFT_STATE_SYS_ERROR and the error set to ETIMEDOUT. But the task routine will continue to run till the end. # Symmetry of the algorithm and the protocol In our system, algorithms and protocols are highly symmetrical on a very abstract level. Just as there are thread tasks with user-defined algorithms, there are network tasks with user-defined protocols. A user-defined algorithm requires the user to provide the algorithm procedure, and a user-defined protocol requires the user to provide the procedure of serialization and deserialization. You can see an introduction in [Simple client/server based on user-defined protocols](/docs/en/tutorial-10-user_defined_protocol.md). For the user-defined algorithms and the user-defined protocols, both must be very pure. For example, an algorithm is just a conversion procedure from INPUT to OUTPUT, and the algorithm does not know the existence of task, series, etc. The implementation of an HTTP protocol only cares about serialization and deserialization, and does not need to care about the task definition. Instead, the HTTP protocol is referred to in an http task.
# Composite features of thread tasks and network tasks In this example, we use WFThreadTaskFactory to build a thread task. This is the simplest way to get a computing task, and it is sufficient in most cases. Similarly, you can simply define a server and a client with a user-defined protocol. However, in the previous example, we can use the algorithm factory to generate a parallel sorting task, which is obviously not possible with a routine. For a network task, such as a Kafka task, interactions with several machines may be required to get results, but it is completely transparent to users. Therefore, our tasks are composite. If you use our framework skillfully, you can design many composite components. workflow-0.11.8/docs/en/tutorial-09-http_file_server.md000066400000000000000000000206441476003635400230000ustar00rootroot00000000000000# Http server with file IO: http\_file\_server # Sample code [tutorial-09-http\_file\_server.cc](/tutorial/tutorial-09-http_file_server.cc) # About http\_file\_server http\_file\_server is a web server. You can start a web server after specifying the startup port and the root path (the default setting is the current path). You can also specify a certificate file and a key file in PEM format to start an HTTPS web server. User may access the server through command line, the request will be sent to IP address 127.0.0.1. The program mainly demonstrates how to use disk IO tasks. In the Linux system, we use the aio interface in the kernel of Linux, and the file reading is completely asynchronous. # Starting a server For starting a server, the steps are almost the same as those when starting an echo server or an HTTP proxy. There is one more way to start an SSL server here: ~~~cpp class WFServerBase { ... int start(unsigned short port, const char *cert_file, const char *key_file); ... }; ~~~ In other words, you can specify a cert file and a key file in PEM format to start an SSL server. 
In addition, when you define a server, you can use **std::bind()** to bind a root parameter to the process. The root parameter means the root path of the service. ~~~cpp void process(WFHttpTask *server_task, const char *root) { ... } int main(int argc, char *argv[]) { ... const char *root = (argc >= 3 ? argv[2] : "."); auto&& proc = std::bind(process, std::placeholders::_1, root); WFHttpServer server(proc); // start server ... } ~~~ # Handling requests Similar to http\_proxy, no threads are occupied in file reading. Instead, an asynchronous task is generated to read files, and a reply to the request is generated after the reading is completed. Please note again that the complete reply data should be read into the memory before the reply message is sent. Therefore, it is not suitable for transferring very large files. ~~~cpp void process(WFHttpTask *server_task, const char *root) { // generate abs path. ... int fd = open(abs_path.c_str(), O_RDONLY); if (fd >= 0) { size_t size = lseek(fd, 0, SEEK_END); void *buf = malloc(size); /* As an example, assert(buf != NULL); */ WFFileIOTask *pread_task; pread_task = WFTaskFactory::create_pread_task(fd, buf, size, 0, pread_callback); /* To implement a more complicated server, please use series' context * instead of tasks' user_data to pass/store internal data. */ pread_task->user_data = resp; /* pass resp pointer to pread task. */ server_task->user_data = buf; /* to free() in callback() */ server_task->set_callback([](WFHttpTask *t){ free(t->user_data); }); series_of(server_task)->push_back(pread_task); } else { resp->set_status_code("404"); resp->append_output_body("404 Not Found."); } } ~~~ Unlike http\_proxy, which generates a new HTTP client task, here a pread task is generated by the factory. [WFTaskFactory.h](/src/factory/WFTaskFactory.h) contains the definitions of relevant interfaces. ~~~cpp struct FileIOArgs { int fd; void *buf; size_t count; off_t offset; }; ...
using WFFileIOTask = WFFileTask; using fio_callback_t = std::function; ... class WFTaskFactory { public: ... static WFFileIOTask *create_pread_task(int fd, void *buf, size_t count, off_t offset, fio_callback_t callback); static WFFileIOTask *create_pwrite_task(int fd, void *buf, size_t count, off_t offset, fio_callback_t callback); /* Interface with file path name */ static WFFileIOTask *create_pread_task(const std::string& path, void *buf, size_t count, off_t offset, fio_callback_t callback); static WFFileIOTask *create_pwrite_task(const std::string& path, void *buf, size_t count, off_t offset, fio_callback_t callback); }; ~~~ Both pread and pwrite return WFFileIOTask. We do not distinguish between sort and psort, and we do not distinguish between client and server task. They all follow the same principle. In addition to these two interfaces, preadv and pwritev return WFFileVIOTask; fsync and fdsync return WFFileSyncTask. You can see the details in the header file. The example uses the user\_data field of the task to save the global data of the service. For larger services, we recommend to use series context. You can see the [proxy examples](/tutorial/tutorial-05-http_proxy.cc) for details. # Handling file reading results ~~~cpp using namespace protocol; void pread_callback(WFFileIOTask *task) { FileIOArgs *args = task->get_args(); long ret = task->get_retval(); HttpResponse *resp = (HttpResponse *)task->user_data; /* close fd only when you created File IO task with **fd** interface. */ close(args->fd); if (ret < 0) { resp->set_status_code("503"); resp->append_output_body("503 Internal Server Error."); } else /* Use '_nocopy' carefully. */ resp->append_output_body_nocopy(args->buf, ret); } ~~~ Use **get\_args()** of the file task to get the input parameters. Here it is a FileIOArgs struct, and it's **fd** field will be -1 if the task was created with **pathname**. Use **get\_retval()** to get the return value of the operation. If ret < 0, the task fails. 
Otherwise, the ret is the size of the read data. In the file task, ret < 0 and task->get\_state() != WFT\_STATE\_SUCCESS are completely equivalent. The memory of the buf field is managed by ourselves. You can use **append\_output\_body\_nocopy()** to pass that memory to resp. After the reply is completed, we will **free()** this block of memory with this line in the process: server\_task->set\_callback(\[](WFHttpTask \*t){ free(t->user\_data); }); # Interact with the server through command line After the server is started, users may access it through command line. Simply input the file name that you want to get, or input Ctrl-D to end the program. The repeating process is implemented by using WFRepeaterTask, which can be created by this factory function: ~~~cpp using repeated_create_t = std::function; using repeater_callback_t = std::function; class WFTaskFactory { WFRepeaterTask *create_repeater_task(repeated_create_t create, repeater_callback_t callback); }; ~~~ As above, a repeater task is created with a task creator function. The repeater calls the task creator repeatedly and runs the task until the creator returns a NULL pointer. When the user's input is not empty, our creator will create an HTTP task on IP 127.0.0.1 to access the server. ~~~cpp { auto&& create = [&scheme, port](WFRepeaterTask *)->SubTask *{ ... scanf("%1023s", buf); if (*buf == '\0') return NULL; std::string url = scheme + "127.0.0.1:" + std::to_string(port) + "/" + buf; WFHttpTask *task = WFTaskFactory::create_http_task(url, 0, 0, [](WFHttpTask *task) { ... }); return task; }; WFFacilities::WaitGroup wg(1); WFRepeaterTask *repeater; repeater = WFTaskFactory::create_repeater_task(create, [&wg](WFRepeaterTask *) { wg.done(); }); repeater->start(); wg.wait(); server.stop(); } ~~~ Finally, when the creator returns NULL, the repeater's callback is called and the program ends.
# About the implementation of the file IO Linux operating system supports a set of asynchronous IO system calls with high efficiency and very little CPU occupation. If you use our framework in a Linux system, this set of interfaces are used by default. We have implemented a set of posix aio interfaces to support other UNIX systems, and used the sigevent notification method of threads, but it is no longer in use because of its low efficiency. Currently, for non-Linux systems, asynchronous IO is always simulated by multi-threading. When an IO task arrives, a thread is created in real time to execute IO tasks, and then a callback is used to return to the handler thread pool. Multi-threaded IO is also the only choice in macOS, because macOS does not have good sigevent support and posix aio will not work in macOS. Some UNIX systems do not support fdatasync. In this case, an fdsync task is equivalent to an fsync task. workflow-0.11.8/docs/en/tutorial-10-user_defined_protocol.md000066400000000000000000000331271476003635400240010ustar00rootroot00000000000000# A simple user-defined protocol: client/server # Sample codes [message.h](/tutorial/tutorial-10-user_defined_protocol/message.h) [message.cc](/tutorial/tutorial-10-user_defined_protocol/message.cc) [server.cc](/tutorial/tutorial-10-user_defined_protocol/server.cc) [client.cc](/tutorial/tutorial-10-user_defined_protocol/client.cc) # About user\_defined\_protocol This example designs a simple communication protocol, and builds a server and a client on that protocol. The server converts the message sent by client into uppercase and returns it to the client. # Protocol format The protocol message contains one 4-byte head and one message body. Head is an integer in network byte order, indicating the length of body. The formats of the request messages and the response messages are identical. 
# Protocol implementation A user-defined protocol should provide its own serialization and deserialization methods, which are virtual functions in the ProtocolMessage class. In addition, for the convenience of use, we strongly recommend that users implement the **move constructor** and **move assignment** for messages (for std::move()). [ProtocolMessage.h](/src/protocol/ProtocolMessage.h) contains the following serialization and deserialization interfaces: ~~~cpp namespace protocol { class ProtocolMessage : public CommMessageOut, public CommMessageIn { private: virtual int encode(struct iovec vectors[], int max); /* You have to implement one of the 'append' functions, and the first one * with arguement 'size_t *size' is recommmended. */ virtual int append(const void *buf, size_t *size); virtual int append(const void *buf, size_t size); ... }; } ~~~ ### Serialization function: encode * The encode function is called before the message is sent, and it is called only once for each message. * In the encode function, you need to serialize the message into a vector array, and the number of array elements must not exceed max. Currently the value of max is 2048. * For the definition of **struct iovec**, please see the system calls **readv** or **writev**. * Normally the return value of the encode function is between 0 and max, indicating how many vectors are used in the message. * In case of UDP protocol, please note that the total length must not be more than 64k, and no more than 1024 vectors are used (in Linux, writev writes only 1024 vectors at one time). * A return value of -1 from encode indicates an error. When returning -1, you need to set errno. If the return value is > max, you will get an EOVERFLOW error. All errors are obtained in the callback. * For performance reasons, the content pointed to by the iov\_base pointer in the vector will not be copied. So it generally points to a member of the message class.
### Deserialization function: append * The append function is called every time a data block is received. Therefore, for each message, it may be called multiple times. * buf and size are the content and the length of received data block respectively. You need to move the data content. * If the interface **append(const void \*buf, size\_t \*size)** is implemented, you can tell the framework how much length is consumed at this time by modifying \* size. remaining size = received size - consumed size, and the remaining part of the buf will be received again when the append is called next time. This function is more convenient for protocol parsing. Of course, you can also move the whole content and manage it by yourself. In this case, you do not need to modify \*size. * If the **append** function returns 0, it indicates that the message is incomplete and the transmission continues. The return value of 1 indicates the end of the message. -1 indicates errors, and you need to set errno. * In a word, the append function is used to tell the framework whether the message transmission is completed or not. Please don't perform complicated and unnecessary protocol parsing in the append. ### Setting the errno * If encode or append returns -1 or other negative numbers, it should be interpreted as failure, and you should set the errno to pass the error reason. You can obtain this error in the callback. * If the system calls or the library functions such as libc fail (for example, malloc), libc will definitely set errno, and you do not need to set it again. * Some errors of illegal messages are quite common. For example, EBADMSG or EMSGSIZE can be used to indicate that the message content is wrong and the message is too large respectively. * You can use a value that exceeds the errno range defined in the system to indicate a user-defined error. Generally, you can use a value greater than 256. * Please do not use a negative errno. 
Negative numbers are used inside the framework to indicate SSL errors. In our example, the serialization and deserialization of messages are very simple. The header file [message.h](/tutorial/tutorial-10-user_defined_protocol/message.h) declares the request class and the response class. ~~~cpp namespace protocol { class TutorialMessage : public ProtocolMessage { private: virtual int encode(struct iovec vectors[], int max); virtual int append(const void *buf, size_t size); ... }; using TutorialRequest = TutorialMessage; using TutorialResponse = TutorialMessage; } ~~~ Both the request class and the response class belong to the same type of messages. You can directly introduce them with using. Note that both the request and the response must be constructible without parameters. In other words, you must provide a constructor without parameters or no constructor. In addition, the response object may be destroyed and reconstructed during communication if a retry occurs; therefore it should be a RAII class, otherwise things will get complicated. [message.cc](/tutorial/tutorial-10-user_defined_protocol/message.cc) contains the implementation of encode and append: ~~~cpp namespace protocol { int TutorialMessage::encode(struct iovec vectors[], int max/*max==8192*/) { uint32_t n = htonl(this->body_size); memcpy(this->head, &n, 4); vectors[0].iov_base = this->head; vectors[0].iov_len = 4; vectors[1].iov_base = this->body; vectors[1].iov_len = this->body_size; return 2; /* return the number of vectors used, no more then max.
*/ } int TutorialMessage::append(const void *buf, size_t size) { if (this->head_received < 4) { size_t head_left; void *p; p = &this->head[this->head_received]; head_left = 4 - this->head_received; if (size < 4 - this->head_received) { memcpy(p, buf, size); this->head_received += size; return 0; } memcpy(p, buf, head_left); size -= head_left; buf = (const char *)buf + head_left; p = this->head; this->body_size = ntohl(*(uint32_t *)p); if (this->body_size > this->size_limit) { errno = EMSGSIZE; return -1; } this->body = (char *)malloc(this->body_size); if (!this->body) return -1; this->body_received = 0; } size_t body_left = this->body_size - this->body_received; if (size > body_left) { errno = EBADMSG; return -1; } memcpy(this->body, buf, size); if (size < body_left) return 0; return 1; } } ~~~ The implementation of encode is very simple, in which two vectors are always used, pointing to the head and the body respectively. Note that the iov\_base pointer must point to a member of the message class. When you use append, you should ensure that the 4-byte head is received completely before reading the message body. Moreover, we can't guarantee that the first append must contain a complete head, so the process is a little cumbersome. The append implements the size\_limit function, and an EMSGSIZE error will be returned if the size\_limit is exceeded. You can ignore the size\_limit field if you don't need to limit the message size. Because we require the communication protocol to be two-way, with one request and one response, users do not need to consider the so-called "TCP packet sticking" problem. Such a situation should be treated directly as an erroneous message. Now, with the definition and implementation of messages, we can build a server and a client. # Server and client definitions With the request and response classes, we can build a server and a client based on this protocol.
The previous example explains the type definitions related to an HTTP protocol: ~~~cpp using WFHttpTask = WFNetworkTask; using http_callback_t = std::function; using WFHttpServer = WFServer; using http_process_t = std::function; ~~~ Similarly, for the protocol in this tutorial, there is no difference in the definitions of data types: ~~~cpp using WFTutorialTask = WFNetworkTask; using tutorial_callback_t = std::function; using WFTutorialServer = WFServer; using tutorial_process_t = std::function; ~~~ # server There is no difference between this server and an ordinary HTTP server. We give priority to IPv6 startup, which does not affect the client requests in IPv4. In addition, the maximum request size is limited to 4KB. Please see [server.cc](/tutorial/tutorial-10-user_defined_protocol/server.cc) for the complete code. # client The logic of the client is to receive the user input from standard IO, construct a request, send it to the server and get the results. Here we use WFRepeaterTask to implement the repeating process, which terminates when the user's input is empty. For the sake of security, we limit the packet size of the server reply to 4KB. The only thing that a client needs to know is how to generate a client task on a user-defined protocol. There are four interface options in [WFTaskFactory.h](/src/factory/WFTaskFactory.h): ~~~cpp template class WFNetworkTaskFactory { private: using T = WFNetworkTask; public: static T *create_client_task(TransportType type, const std::string& host, unsigned short port, int retry_max, std::function callback); static T *create_client_task(TransportType type, const std::string& url, int retry_max, std::function callback); static T *create_client_task(TransportType type, const ParsedURI& uri, int retry_max, std::function callback); static T *create_client_task(TransportType type, const struct sockaddr *addr, socklen_t addrlen, int retry_max, std::function callback); ...
}; ~~~ Among them, TransportType specifies the transport layer protocol, and the current options include TT\_TCP, TT\_UDP, TT\_SCTP, TT\_TCP\_SSL and TT\_SCTP\_SSL. There is little difference between the interfaces. In our example, the URL is not needed for the time being. We use a domain name and a port to create a task. The actual code is shown as follows. We inherited the WFTaskFactory class, but this derivation is not required. ~~~cpp using namespace protocol; class MyFactory : public WFTaskFactory { public: static WFTutorialTask *create_tutorial_task(const std::string& host, unsigned short port, int retry_max, tutorial_callback_t callback) { using NTF = WFNetworkTaskFactory; WFTutorialTask *task = NTF::create_client_task(TT_TCP, host, port, retry_max, std::move(callback)); task->set_keep_alive(30 * 1000); return task; } }; ~~~ You can see that we used the WFNetworkTaskFactory\<TutorialRequest, TutorialResponse\> class to create a client task. Next, by calling the **set\_keep\_alive()** interface of the task, the connection is kept for 30 seconds after the communication is completed. Otherwise, a short connection is used by default. The rest of the client code has been explained in the previous examples. Please see [client.cc](/tutorial/tutorial-10-user_defined_protocol/client.cc). # How is the request on a built-in protocol generated Currently, there are five built-in protocols in the framework: HTTP, Redis, MySQL, Kafka and DNS. Can we generate an HTTP or Redis task in the same way? For example: ~~~cpp WFHttpTask *task = WFNetworkTaskFactory::create_client_task(...); ~~~ Please note that an HTTP task generated in this way will lose a lot of functions. For example, it is impossible to identify whether to use persistent connection according to the header, and it is impossible to identify redirection, etc. Similarly, if a MySQL task is generated in this way, it may not run at all, because there is no login authentication process.
A Kafka request may need to have complicated interactions with multiple brokers, so the request created in this way obviously cannot complete this process. This shows that the generation of one message in each built-in protocol is far more complicated than that in this example. Similarly, if you need to implement a communication protocol with more functions, there are still many codes to write. workflow-0.11.8/docs/en/tutorial-11-graph_task.md000066400000000000000000000065631476003635400215540ustar00rootroot00000000000000# Direct Acyclic Graph (DAG):graph_task # Sample code [tutorial-11-graph_task.cc](/tutorial/tutorial-11-graph_task.cc) # About graph_task The graph_task example demonstrates how to implement more complex inter-task dependencies by building a DAG. # Create tasks in the DAG In this tutorial, we create a timer task, two http fetching task, and a 'go' task. Timer task executes a delay of 1 second before fetching, http tasks fetch the home page of 'sogou' and 'baidu' in parallel, and after all of that, go task will print the fetching result. The dependencies of the tasks are: ~~~ +-------+ +---->| Http1 |-----+ | +-------+ | +-------+ +-v--+ | Timer | | Go | +-------+ +-^--+ | +-------+ | +---->| Http2 |-----+ +-------+ ~~~ # Create the graph task Graph is a kind of task as well. We can create a graph task by this function: ~~~cpp class WFTaskFactory { public: static WFGraphTask *create_graph_task(graph_callback_t callback); ... }; ~~~ The graph is a empty graph after it's creation. Of course you may run an empty graph and will get to it callback immediately. 
# Create graph nodes We've got 4 ordinary tasks, which cannot be added to the graph directly but need to be turned into graph nodes: ~~~cpp { /* Create graph nodes */ WFGraphNode& a = graph->create_graph_node(timer); WFGraphNode& b = graph->create_graph_node(http_task1); WFGraphNode& c = graph->create_graph_node(http_task2); WFGraphNode& d = graph->create_graph_node(go_task); } ~~~ The ``create_graph_node`` interface of WFGraphTask creates a graph node that refers to a task. And we can use the references of graph nodes to specify the dependencies between them. Otherwise, they are all standalone nodes, and will run in parallel when the graph task is started. # Build the graph By using the '-->' or '<--' operators, we can specify the dependencies: ~~~cpp { /* Build the graph */ a-->b; a-->c; b-->d; c-->d; } ~~~ And now we've built the graph that we described. And we can use it like an ordinary task. Also, each of the following code snippets is legal and equivalent: ~~~cpp { a-->b-->d; a-->c-->d; } ~~~ ~~~cpp { d<--b<--a; d<--c<--a; } ~~~ ~~~cpp { d<--b<--a-->c-->d; } ~~~ # Canceling successors In graph tasks, we extend SeriesWork's **cancel** operation. When the series of a graph node is canceled, the operation will apply to all its successor nodes recursively. The **cancel** operation is usually used in a task's callback: ~~~cpp int main() { WFGraphTask *graph = WFTaskFactory::create_graph_task(graph_callback); WFHttpTask *task = WFTaskFactory::create_http_task(url, 0, 0, [](WFHttpTask *t){ if (t->get_state() != WFT_STATE_SUCCESS) series_of(t)->cancel(); }); WFGraphNode& a = graph->create_graph_node(task); WFGraphNode& b = ...; WFGraphNode& c = ...; WFGraphNode& d = ...; a-->b-->c; b-->d; graph->start(); ... } ~~~ In this case, when the http task fails, nodes b, c, d will all be canceled, because the operation is recursive. # Data passing Because the tasks in a graph don't share the same series, there is no general method for passing data between graph nodes.
# Acknowledgement Some designs are inspired by [taskflow](https://github.com/taskflow/taskflow). workflow-0.11.8/docs/en/tutorial-12-mysql_cli.md000066400000000000000000000417251476003635400214250ustar00rootroot00000000000000# Asynchronous MySQL client: mysql\_cli # Sample code [tutorial-12-mysql\_cli.cc](/tutorial/tutorial-12-mysql_cli.cc) # About mysql\_cli The usage of mysql\_cli in the tutorial is similar to that of the official client. It is an asynchronous MySQL client with an interactive command line interface. To start the program, run the command: ./mysql_cli \ After startup, you can directly enter MySQL command in the terminal to interact with db, or enter `quit` or `Ctrl-C` to exit. # Format of MySQL URL mysql://username:password@host:port/dbname?character\_set=charset&character\_set\_results=charset - set scheme to be **mysqls://** for accessing MySQL with SSL connnection (MySQL server 5.7 or above is required). - fill in the username and the password for the MySQL database; Special characters in password need to be escaped. ~~~cpp // Password: @@@@#### std::string url = "mysql://root:" + StringUtil::url_encode_component("@@@@####") + "@127.0.0.1"; ~~~ - the default port number is 3306; - **dbname** is the name of the database to be used. It is recommended to provide a dbname if SQL statements only operates on one database; - If you have upstream selection requirements for MySQL, please see [upstream documents](/docs/en/about-upstream.md). - **character_set** indicates a character set used for the client, with the same meaning of --default-character-set in official client. The default value is utf8. For details, please see official MySQL documents [character-set.html](https://dev.mysql.com/doc/internals/en/character-set.html). - **character_set_results** indicates a character set for client, connection and results. If you wants to use `SET NAMES ` in SQL statements, please set it here. 
Sample MySQL URL: mysql://root:password@127.0.0.1 mysql://@test.mysql.com:3306/db1?character\_set=utf8&character_set_results=utf8 mysqls://localhost/db1?character\_set=big5 # Creating and starting a MySQL task You can use WFTaskFactory to create a MySQL task. The usage of creating interface and callback functions are similar to other tasks in workflow: ~~~cpp using mysql_callback_t = std::function; WFMySQLTask *create_mysql_task(const std::string& url, int retry_max, mysql_callback_t callback); void set_query(const std::string& query); ~~~ You can call **set\_query()** on the request to write SQL statements after creating a WFMySQLTask. If **set_query()** had **NOT** been called before the task started, the user might get **WFT_ERR_MYSQL_QUERY_NOT_SET** in callback. Other functions, including callback, series and user\_data are used in a way similar to other tasks in workflow. The following codes show some general usage: ~~~cpp int main(int argc, char *argv[]) { ... WFMySQLTask *task = WFTaskFactory::create_mysql_task(url, RETRY_MAX, mysql_callback); task->get_req()->set_query("SHOW TABLES;"); ... task->start(); ... } ~~~ # Supported commands Currently the supported command is **COM\_QUERY**, which can cover the basic requirements for adding, deleting, modifying and querying data, creating and deleting databases, creating and deleting tables, prepare, using stored procedures and using transactions. Because the program doesn't support the selection of databases (**USE** command) in our interactive commands, if there are **cross-database** operations in SQL statements, you can specify the database and table with **db\_name.table\_name**. Any other command can be **spliced together** and then passed to WFMySQLTask with `set_query()`. (including INSERT/UPDATE/SELECT/PREPARE/CALL) The spliced commands will be executed sequentially until an error occurs, and the previous commands will be executed successfully. 
For example: ~~~cpp req->set_query("SELECT * FROM table1; CALL procedure1(); INSERT INTO table3 (id) VALUES (1);"); ~~~ # Parsing results Similar to other tasks in workflow, you can use **task->get\_resp()** to get **MySQLResponse**. For details on the interfaces, please see [MySQLResult.h](/src/protocol/MySQLResult.h). One request will get one response, which is a 3-dimensional structure. - one response consists of one or more result sets; - the type of each result set may be **MYSQL_STATUS_GET_RESULT** or **MYSQL_STATUS_OK**; - one result set of type **MYSQL_STATUS_GET_RESULT** consists of one or more rows; - one row consists of one or more fields, or data cells; The two types of result sets can be distinguished by ``cursor->get_cursor_status()``. | |MYSQL_STATUS_GET_RESULT|MYSQL_STATUS_OK| |------|-----------------------|---------------| |SQL command|SELECT(including each SELECT in PROCEDURE)|INSERT / UPDATE / DELETE / ...| |Semantics|Read. One result set consists of a 2-dimensional structure
representing the response of one read operation.|Write. One result set represents the results of
one write operation.| |main APIs|fetch_fields();
fetch_row(&row_arr);
...|get_insert_id();
get_affected_rows();
...| When errors occur in spliced commands, you may first use **MySQLResultCursor** to get the result sets of the commands that have been executed successfully. Then determine whether ``resp->get_packet_type()`` equals **MYSQL_PACKET_ERROR** and get the specific error information through ``resp->get_error_code()`` and ``resp->get_error_msg()``. A **PROCEDURE** command containing N **SELECT** statements will return N result sets of **MYSQL_STATUS_GET_RESULT** and 1 result set of **MYSQL_STATUS_OK**. It is fine for the user to ignore this **MYSQL_STATUS_OK** result set. To get all the data, the specific steps should be: 1. checking the task state (state at communication): you can check whether the task is successfully executed by checking whether ``task->get_state()`` is equal to **WFT_STATE_SUCCESS**; 2. determining the type of the response packet (state at parsing the return packet): call ``resp->get_packet_type()`` to check the type of the last SQL query return packet. The common types include: - MYSQL\_PACKET\_OK: parsed successfully, should use cursor to get all the result sets. - MYSQL\_PACKET\_EOF: parsed successfully, should use cursor to get all the result sets. - MYSQL\_PACKET\_ERROR: the request failed or partially failed; you may use cursor to get the result sets of those successful commands. 3. traversing the result sets: you can use **MySQLResultCursor** to read the content in the result set. Because the data returned by a MySQL server may contain multiple result sets, the cursor will **automatically point to the reading position of the first result set** at first. 4. checking the result set state (state at reading the result sets): **cursor->get_cursor_status()** returns the following states: - MYSQL\_STATUS\_GET\_RESULT: current result set is a READ result set; - MYSQL\_STATUS\_END: the last record of the current READ result set has been read; - MYSQL\_STATUS\_OK: current result set is a WRITE result set; - MYSQL\_STATUS\_ERROR: parsing error; 5.
reading the basic content of **MYSQL_STATUS_OK** result set: - ``unsigned long long get_affected_rows() const;`` - ``unsigned long long get_insert_id() const;`` - ``int get_warnings() const;`` - ``std::string get_info() const;`` 6. reading each field and each columns of **MYSQL_STATUS_GET_RESULT** result set: - `int get_field_count() const;` - `const MySQLField *fetch_field();` - `const MySQLField *const *fetch_fields() const;` 7. reading each line of **MYSQL_STATUS_GET_RESULT** result set: you can use ``cursor->fetch_row()`` to read by row until the return value is false, in which the offset within the cursor that points to the row in the current result set will be moved: - `int get_rows_count() const;` - `bool fetch_row(std::vector& row_arr);` - `bool fetch_row(std::map& row_map);` - `bool fetch_row(std::unordered_map& row_map);` - `bool fetch_row_nocopy(const void **data, size_t *len, int *data_type);` 8. taking out all the rows in the current **MYSQL_STATUS_GET_RESULT** result set directly: you can use ``cursor->fetch_all()`` to read all rows, and the cursor that is used to record the rows internally will be moved directly to the end; The cursor state changes to **MYSQL_STATUS_END**: - `bool fetch_all(std::vector>& rows);` 9. returning to the head of the current **MYSQL_STATUS_GET_RESULT** result set: if it is necessary to read this result set again, you can use ``cursor->rewind()`` to return to the head of the current result set, and then read it via the operations in Step 7 or Step 8; 10. getting the next result set: because the data packet returned by MySQL server may contains multiple result sets (for example, each SELECT/INSERT/... statement gets one result set; or the multiple result sets returned by calling a PROCEDURE). Therefore, you can use ``cursor->next_result_set()`` to jump to the next result set. If the return value is false, it means that all result sets have been taken. 11. 
returning to the first result set: use **cursor->first\_result\_set()** to return to the heads of all result sets, and then you can repeat the operations from Step 4. 12. getting the data of each column (MySQLCell): the row read in Step 7 is composed of multiple columns, and the result of each column is one MySQLCell. It mainly uses the following interfaces: - `int get_data_type();` returns MYSQL\_TYPE\_LONG, MYSQL\_TYPE\_STRING, etc. For the details, please see [mysql\_types.h](/src/protocol/mysql_types.h). - `bool is_TYPE() const;` the TYPE is int, string or ulonglong. It is used to check the data type. - `TYPE as_TYPE() const;` same as the above. It reads the data from MySQLCell in a certain type. - `void get_cell_nocopy(const void **data, size_t *len, int *data_type) const;` nocopy interface. The whole example is shown below: ~~~cpp void task_callback(WFMySQLTask *task) { // step-1. Check the status of the task if (task->get_state() != WFT_STATE_SUCCESS) { fprintf(stderr, "task error = %d\n", task->get_error()); return; } MySQLResultCursor cursor(task->get_resp()); bool test_first_result_set_flag = false; bool test_rewind_flag = false; // step-2. Check other status of repsponse packet if (resp->get_packet_type() == MYSQL_PACKET_ERROR) { fprintf(stderr, "ERROR. error_code=%d %s\n", task->get_resp()->get_error_code(), task->get_resp()->get_error_msg().c_str()); } begin: // step-3. Traverse the result sets do { // step-4. Check the status of the result set if (cursor.get_cursor_status() == MYSQL_STATUS_OK) { // step-5. Read the basic content of MYSQL_STATUS_OK result set fprintf(stderr, "OK. %llu rows affected. %d warnings. insert_id=%llu.\n", cursor.get_affected_rows(), cursor.get_warnings(), cursor.get_insert_id()); } else if (cursor.get_cursor_status() == MYSQL_STATUS_GET_RESULT) { fprintf(stderr, "field_count=%u rows_count=%u ", cursor.get_field_count(), cursor.get_rows_count()); // step-6. Read each fields.
This is a nocopy api const MySQLField *const *fields = cursor.fetch_fields(); for (int i = 0; i < cursor.get_field_count(); i++) { fprintf(stderr, "db=%s table=%s name[%s] type[%s]\n", fields[i]->get_db().c_str(), fields[i]->get_table().c_str(), fields[i]->get_name().c_str(), datatype2str(fields[i]->get_data_type())); } // step-8. Read all the rows. You may use while (cursor.fetch_row(map/vector)) to get each rows accoding to step-7 std::vector> rows; cursor.fetch_all(rows); for (unsigned int j = 0; j < rows.size(); j++) { // step-12. Read each cell for (unsigned int i = 0; i < rows[j].size(); i++) { fprintf(stderr, "[%s][%s]", fields[i]->get_name().c_str(), datatype2str(rows[j][i].get_data_type())); // step-12. Check the type wih is_string()and transform the type with as_string() if (rows[j][i].is_string()) { std::string res = rows[j][i].as_string(); fprintf(stderr, "[%s]\n", res.c_str()); } else if (rows[j][i].is_int()) { fprintf(stderr, "[%d]\n", rows[j][i].as_int()); } // else if ... } } } // step-10. Get the next result set } while (cursor.next_result_set()); if (test_first_result_set_flag == false) { test_first_result_set_flag = true; // step-11. Go back to the first result set cursor.first_result_set(); goto begin; } if (test_rewind_flag == false) { test_rewind_flag = true; // step-9. Go back to the first position of the current result set cursor.rewind(); goto begin; } return; } ~~~ # WFMySQLConnection Since it is a highly concurrent asynchronous client, this means that the client may have more than one connection to the server. As both MySQL transactions and preparation are stateful, in order to ensure that one transaction or preparation ocupies one connection exclusively, you can use our encapsulated secondary factory WFMySQLConnection to create a task. Each WFMySQLConnection guarantees that one connection is occupied exclusively. For the details, please see [WFMySQLConnection.h](/src/client/WFMySQLConnection.h). ### 1\. 
Creating and initializing WFMySQLConnection When creating a WFMySQLConnection, you need to pass in **id**, and the subsequent calls on this WFMySQLConnection will use this id and url to find the corresponding unique connection. When initializing a WFMySQLConnection, you need to pass a URL, and then you do not need to set the URL for the task created on this connection. ~~~cpp class WFMySQLConnection { public: WFMySQLConnection(int id); int init(const std::string& url); ... }; ~~~ ### 2\. Creating a task and closing a connection With **create\_query\_task()**, you can create a task by writing an SQL request and a callback function. The task is guaranteed to be sent on this connection. Sometimes you need to close this connection manually, because after you stop using it, the connection will be kept until the MySQL server times out. During this period, if you use the same id and url to create a WFMySQLConnection, you may reuse the connection. Therefore, we suggest that if you do not want to reuse the connection, you should use **create\_disconnect\_task()** to create a task and manually close the connection. ~~~cpp class WFMySQLConnection { public: ... WFMySQLTask *create_query_task(const std::string& query, mysql_callback_t callback); WFMySQLTask *create_disconnect_task(mysql_callback_t callback); } ~~~ WFMySQLConnection is equivalent to a secondary factory. In the framework, we arrange that the life cycle of any factory object does not need to be maintained until the task ends. The following code is completely legal: ~~~cpp WFMySQLConnection *conn = new WFMySQLConnection(1234); conn->init(url); auto *task = conn->create_query_task("SELECT * from table", my_callback); conn->deinit(); delete conn; task->start(); ~~~ ### 3\. Cautions Do not generate new connection ids indefinitely, because each id will occupy a little memory.
The connection will be put to an internal connection pool if user didn't run a disconnect task, and will been reused by another WFMySQLConnection object initialized with the same id and url. Running tasks on the same connection parallelly will fail with error **EAGAIN**. If you have started `BEGIN` but have not `COMMIT` or `ROLLBACK` during the transaction and the connection has been interrupted during the transaction, the connection will be automatically reconnected internally by the framework, and you will get **ECONNRESET** error in the next task request. In this case, the transaction statements those have not been `COMMIT` would be expired and you may need to send them again. ### 4\. Preparation You can also use the WFMySQLConnection for **PREPARE**. And you can easily use it to **defend against SQL injection**. If the connection is reconnected, you also get an **ECONNRESET** error. ### 5\. Complete example ~~~cpp WFMySQLConnection conn(1); conn.init("mysql://root@127.0.0.1/test"); // test transaction const char *query = "BEGIN;"; WFMySQLTask *t1 = conn.create_query_task(query, task_callback); query = "SELECT * FROM check_tiny FOR UPDATE;"; WFMySQLTask *t2 = conn.create_query_task(query, task_callback); query = "INSERT INTO check_tiny VALUES (8);"; WFMySQLTask *t3 = conn.create_query_task(query, task_callback); query = "COMMIT;"; WFMySQLTask *t4 = conn.create_query_task(query, task_callback); WFMySQLTask *t5 = conn.create_disconnect_task(task_callback); SeriesWork *series = create_series_work(t1, nullptr); *series << t2 << t3 << t4 << t5; series->start(); ~~~ workflow-0.11.8/docs/en/tutorial-13-kafka_cli.md000066400000000000000000000304641476003635400213340ustar00rootroot00000000000000# Asynchronous Kafka Client: kafka_cli # Sample Codes [tutorial-13-kafka_cli.cc](/tutorial/tutorial-13-kafka_cli.cc) # About Compiler Because of supporting multiple compression methods of Kafka, [zlib](https://github.com/madler/zlib.git), 
[snappy](https://github.com/google/snappy.git), [lz4(>=1.7.5)](https://github.com/lz4/lz4.git), [zstd](https://github.com/facebook/zstd.git) and other third-party libraries are used in the compression algorithms in the Kafka protocol, and they must be installed before the compilation. It supports both CMake and Bazel for compiling. CMake: You can use **make KAFKA=y** to compile a separate library for Kafka protocol (libwfkafka.a and libwfkafka.so) and use **cd tutorial; make KAFKA=y** to compile kafka_cli. Bazel: You can use **bazel build kafka** to compile a separate library for Kafka protocol and use **bazel build kafka_cli** to compile kafka_cli. # About kafka_cli Kafka_cli is a kafka client for producing and fetching messages in Kafka. When you compile the source codes, type the command **make KAFKA=y** in the **tutorial** directory or type the command **make KAFKA=y tutorial** in the root directory of the project. The program then reads kafka broker server addresses and the current task type (produce/fetch) from the command line: ./kafka_cli \<broker_url\> [p/c] The program exits automatically after all the tasks are completed, and all the resources will be completely freed. In the command, the broker_url may contain several urls separated by commas (,). - For instance, kafka://host:port,kafka://host1:port... or: **kafkas**://host:port,**kafkas**://host1:port for kafka over SSL; - The default port is 9092 for TCP and 9093 for SSL; - Do not mix 'kafkas://' with "kafka://", otherwise the init function will fail with errno EINVAL; - If you want to use upstream policy at this layer, please refer to [upstream documents](/docs/en/about-upstream.md).
The following are several Kafka broker_url samples: kafka://127.0.0.1/ kafka://kafka.host:9090/ kafka://10.160.23.23:9000,10.123.23.23,kafka://kafka.sogou kafkas://broker1.kafka.sogou,kafkas://broker2.kafka.sogou Illegal broker_url sample (The first one is SSL, and the second one is not): kafkas://broker1.kafka.sogou,broker2.kafka.sogou # Principles and Features Kafka client has no third-party dependencies internally except for the libraries used in the compression. With the high performance of Workflow, When properly configured and in fair environments, tens of thousands of Kafka requests can be processed in one second. Internally, a Kafka client divides each request into parallel tasks according to the brokers used. In parallel tasks, there is one sub-task for each broker address. In this way, the efficiency is maximized. Besides, the connection reuse mechanism in the Workflow ensures that the total number of connections is kept within a reasonable range. If there are multiple topic partitions under one broker address, you may create multiple clients and then create and start separate tasks for each topic partition to increase the throughput. # Creating and Starting Kafka Tasks To create and start a Kafka task, create a **WFKafkaClient** first and then call **init** to initialize that **WFKafkaClient**. ~~~cpp int init(const std::string& broker_url); int init(const std::string& broker_url, const std::string& group); ~~~ In the above code snippet, **broker_url** means the address of the kafka broker cluster. Its format is the same as the broker_url in the above section. **group** means the group_name of a consumer group, which is used for the consumer group in a fetch task. In the case of produce tasks or fetch tasks without any consumer groups, do not use this interface. For a consumer group, you can specify the heartbeat interval in milliseconds to keep the heartbeats. 
~~~cpp void set_heartbeat_interval(size_t interval_ms); ~~~ Then you can create a Kafka task with that **WFKafkaClient**. ~~~cpp using kafka_callback_t = std::function; WFKafkaTask *create_kafka_task(const std::string& query, int retry_max, kafka_callback_t cb); WFKafkaTask *create_kafka_task(int retry_max, kafka_callback_t cb); ~~~ In the above code snippet, **query** includes the type of the task, the topic and other properties. **retry_max** means the maximum number of retries. **cb** is the user-defined callback function, which will be called after the task is completed. You can also change the default settings of the task to meet the requirements. For details, refer to [KafkaDataTypes.h](/src/protocol/KafkaDataTypes.h). ~~~cpp KafkaConfig config; config.set_client_id("workflow"); task->set_config(std::move(config)); ~~~ The supported configuration items are described below: Item name | Type | Default value | Description ------ | ---- | -------| ------- produce_timeout | int | 100ms | Maximum time for produce. produce_msg_max_bytes | int | 1000000 bytes | Maximum length for one message. produce_msgset_cnt | int | 10000 | Maximum number of messages in one communication set. produce_msgset_max_bytes | int | 1000000 bytes | Maximum length of messages in one communication. fetch_timeout | int | 100ms | Maximum timeout for fetch. fetch_min_bytes | int | 1 byte | Minimum length of messages in one fetch communication. fetch_max_bytes | int | 50M bytes | Maximum length of messages in one fetch communication. fetch_msg_max_bytes | int | 1M bytes | Maximum length of one single message in a fetch communication. offset_timestamp | long long int | -1 | Initial offset in the consumer group mode when there is no offset history. -2 means the oldest offset; -1 means the latest offset. session_timeout | int | 10s | Maximum initialization timeout for joining a consumer group.
rebalance_timeout | int | 10s | Maximum timeout for synchronizing a consumer group information after a client joins the consumer group. produce_acks | int | -1 | Number of brokers to ensure the successful replication of a message before the return of a produce task. -1 indicates all replica brokers. allow_auto_topic_creation | bool | true | Flag for controlling whether a topic is created automatically for the produce task if it does not exist. broker_version | char * | NULL | Version number for brokers, which should be manually specified when the version number is smaller than 0.10. compress_type | int | NoCompress | Compression type for produce messages. client_id | char * | NULL | Identifier of a client. check_crcs | bool | false | Flag for controlling whether to check crc32 in the messages for a fetch task. offset_store | int | 0 | When joining the consumer group, whether to use the last submission offset, 1 means to use the specified offset, and 0 means to use the last submission preferentially. sasl_mechanisms | char * | NULL | Sasl certification type, currently only supports plain, and is on the ongoing development of sasl support. sasl_username | char * | NULL | Username required for sasl authentication. sasl_password | char * | NULL | Password required for sasl authentication. After configuring all the parameters, you can call **start** interface to start the Kafka task. # About Produce Tasks 1\. After you create and initialize a **WFKafkaClient**, you can specify the topic or other information in the **query** to create **WFKafkaTask** tasks. For example: ~~~cpp int main(int argc, char *argv[]) { ... client = new WFKafkaClient(); client->init(url); task = client->create_kafka_task("api=fetch&topic=xxx&topic=yyy", 3, kafka_callback); ... task->start(); ... } ~~~ 2\. After the **WFKafkaTask** is created, call **set_key**, **set_value**, **add_header_pair** and other methods to build a **KafkaRecord**. 
For information about more interfaces on **KafkaRecord**, refer to [KafkaDataTypes.h](/src/protocol/KafkaDataTypes.h). Then you can call **add_produce_record** to add a **KafkaRecord**. For the detailed definitions of the interfaces, refer to [WFKafkaClient.h](/src/client/WFKafkaClient.h). The second parameter **partition** in **add_produce_record**, >=0 means the specified **partition**; -1 means that the **partition** is chosen randomly or the user-defined **kafka_partitioner_t** is used. For **kafka_partitioner_t**, you can call the **set_partitioner** interface to specify the user-defined rules. For example: ~~~cpp int main(int argc, char *argv[]) { ... WFKafkaClient *client_fetch = new WFKafkaClient(); client_fetch->init(url); task = client_fetch->create_kafka_task("api=produce&topic=xxx&topic=yyy", 3, kafka_callback); task->set_partitioner(partitioner); KafkaRecord record; record.set_key("key1", strlen("key1")); record.set_value(buf, sizeof(buf)); record.add_header_pair("hk1", 3, "hv1", 3); task->add_produce_record("workflow_test1", -1, std::move(record)); ... task->start(); ... } ~~~ 3\. You can use one of the four compressions supported by Kafka in the produce task by configuration. For example: ~~~cpp int main(int argc, char *argv[]) { ... WFKafkaClient *client_fetch = new WFKafkaClient(); client_fetch->init(url); task = client_fetch->create_kafka_task("api=produce&topic=xxx&topic=yyy", 3, kafka_callback); KafkaConfig config; config.set_compress_type(Kafka_Zstd); task->set_config(std::move(config)); KafkaRecord record; record.set_key("key1", strlen("key1")); record.set_value(buf, sizeof(buf)); record.add_header_pair("hk1", 3, "hv1", 3); task->add_produce_record("workflow_test1", -1, std::move(record)); ... task->start(); ... } ~~~ # About Fetch Tasks You may use consumer group mode or manual mode for fetch tasks. 1\. Manual mode In this mode, you do not need to specify consumer groups, but you must specify topic, partition and offset. 
For example: ~~~cpp client = new WFKafkaClient(); client->init(url); task = client->create_kafka_task("api=fetch", 3, kafka_callback); KafkaToppar toppar; toppar.set_topic_partition("workflow_test1", 0); toppar.set_offset(0); task->add_toppar(toppar); ~~~ 2\. Consumer group mode In this mode, you must specify the name of the consumer group when initializing a client. For example: ~~~cpp int main(int argc, char *argv[]) { ... WFKafkaClient *client_fetch = new WFKafkaClient(); client_fetch->init(url, cgroup_name); task = client_fetch->create_kafka_task("api=fetch&topic=xxx&topic=yyy", 3, kafka_callback); ... task->start(); ... } ~~~ 3\. Committing offset In the consumer group mode, after a message is consumed, you can create a commit task in the callback to automatically submit the consumption record. For example: ~~~cpp void kafka_callback(WFKafkaTask *task) { ... commit_task = client.create_kafka_task("api=commit", 3, kafka_callback); ... commit_task->start(); ... } ~~~ # Closing the Client In the consumer group mode, before you close a client, you must call **create_leavegroup_task** to create a **leavegroup_task**. This task will send a **leavegroup** packet. If no **leavegroup_task** is started, the group does not know that the client is leaving and will trigger rebalance. # Processing Kafka Results The data structure of the message result set is KafkaResult, and you can call **get_result()** in the **WFKafkaTask** to retrieve the results. Then you can call the **fetch_record** in the **KafkaResult** to retrieve all records of the task. The record is a two-dimensional vector. The first dimension is topic partition, and the second dimension is the **KafkaRecord** under that topic partition. [KafkaResult.h](/src/protocol/KafkaResult.h) contains the definition of **KafkaResult**. ~~~cpp void kafka_callback(WFKafkaTask *task) { int state = task->get_state(); int error = task->get_error(); // handle error states ... 
protocol::KafkaResult *result = task->get_result(); result->fetch_records(records); for (auto &v : records) { for (auto &w: v) { const void *value; size_t value_len; w->get_value(&value, &value_len); printf("produce\ttopic: %s, partition: %d, status: %d, offset: %lld, val_len: %zu\n", w->get_topic(), w->get_partition(), w->get_status(), w->get_offset(), value_len); } } ... protocol::KafkaResult new_result = std::move(*task->get_result()); if (new_result.fetch_records(records)) { for (auto &v : records) { if (v.empty()) continue; for (auto &w: v) { if (fp) { const void *value; size_t value_len; w->get_value(&value, &value_len); fwrite(w->get_value(), w->get_value_len(), 1, fp); } } } } ... } ~~~ workflow-0.11.8/docs/en/xmake.md000066400000000000000000000025431476003635400164500ustar00rootroot00000000000000# xmake compiling ``` // compile workflow library xmake // compile test xmake -g test // run test xmake run -g test // compile tutorial xmake -g tutorial // compile benchmark xmake -g benchmark ``` ## running `xmake run -h` can see which targets you can run Select a target to run, for instance: ``` xmake run tutorial-06-parallel_wget ``` ## xmake install ``` sudo xmake install ``` ## Compile static / shared library ``` // compile static lib xmake f -k static xmake -r ``` ``` // compile shard lib xmake f -k shared xmake -r ``` `tips : -r means -rebuild` ## build options `xmake f --help` can see our defined options. 
``` Command options (Project Configuration): --workflow_inc=WORKFLOW_INC workflow inc (default: /media/psf/pro/workflow/_include) --upstream=[y|n] build upstream component (default: y) --consul=[y|n] build consul component --workflow_lib=WORKFLOW_LIB workflow lib (default: /media/psf/pro/workflow/_lib) --redis=[y|n] build redis component (default: y) --kafka=[y|n] build kafka component --mysql=[y|n] build mysql component (default: y) ``` You can cut or integrate each components with the following commands ``` xmake f --redis=n --kafka=y --mysql=n xmake -r ``` workflow-0.11.8/docs/tutorial-01-wget.md000066400000000000000000000106451476003635400177700ustar00rootroot00000000000000# 创建第一个任务:wget # 示例代码 [tutorial-01-wget.cc](/tutorial/tutorial-01-wget.cc) # 关于wget 程序从stdin读取http/https URL,抓取网页并把内容打印到stdout,并将请求和响应的http header打印在stderr。 为了简单起见,程序用Ctrl-C退出,但会保证所有资源先被完全释放。 # 创建并启动http任务 ~~~cpp WFHttpTask *task = WFTaskFactory::create_http_task(url, REDIRECT_MAX, RETRY_MAX, wget_callback); protocol::HttpRequest *req = task->get_req(); req->add_header_pair("Accept", "*/*"); req->add_header_pair("User-Agent", "Wget/1.14 (gnu-linux)"); req->add_header_pair("Connection", "close"); task->start(); pause(); ~~~ WFTaskFactory::create_http_task()产生一个http任务,在[WFTaskFactory.h](../src/factory/WFTaskFactory.h)文件里,原型定义如下: ~~~cpp WFHttpTask *create_http_task(const std::string& url, int redirect_max, int retry_max, http_callback_t callback); ~~~ 前几个参数不用过多解释,http_callback_t是http任务的callback,定义如下: ~~~cpp using http_callback_t = std::function; ~~~ 说白了,就是一个参数为Task本身,没有返回值的函数。这个callback可以传NULL,表示无需callback。我们一切任务的callback都是这个风格。 需要说明的是,所有工厂函数不会返回失败,所以不用担心task为空指针,哪怕是url不合法。一切错误都在callback再处理。 task->get_req()函数得到任务的request,默认是GET方法,HTTP/1.1,长连接。框架会自动加上request_uri,Host等。 框架会在发送前根据需要自动加上Content-Length或Connection这些http header。用户也可以通过add_header_pair()方法添加自己的header。 关于http消息的更多接口,可以在[HttpMessage.h](../src/protocol/HttpMessage.h)中查看。 
task->start()启动任务,非阻塞,并且不会失败。之后callback必然会在被调用。因为异步的原因,start()以后显然不能再用task指针了。 为了让示例尽量简单,start()之后调用pause()防止程序退出,用户需要Ctrl-C结束程序。 # 处理http抓取结果 在这个示例中,我们使用一个普遍的函数处理结果。当然,std::function支持更多的功能。 ~~~cpp void wget_callback(WFHttpTask *task) { protocol::HttpRequest *req = task->get_req(); protocol::HttpResponse *resp = task->get_resp(); int state = task->get_state(); int error = task->get_error(); // handle error states ... std::string name; std::string value; // print request to stderr fprintf(stderr, "%s %s %s\r\n", req->get_method(), req->get_http_version(), req->get_request_uri()); protocol::HttpHeaderCursor req_cursor(req); while (req_cursor.next(name, value)) fprintf(stderr, "%s: %s\r\n", name.c_str(), value.c_str()); fprintf(stderr, "\r\n"); // print response header to stderr ... // print response body to stdin const void *body; size_t body_len; resp->get_parsed_body(&body, &body_len); // always success. fwrite(body, 1, body_len, stdout); fflush(stdout); } ~~~ 在这个callback里,task就是我们通过工厂产生的task。 task->get_state()与task->get_error()分别获得任务的运行状态和错误码。我们先略过错误处理的部分。 task->get_resp()得到任务的response,这个和request区别不大,都是HttpMessage的派生。 之后通过HttpHeaderCursor对象,对request和response的header进行扫描。在[HttpUtil.h](../src/protocol/HttpUtil.h)可以看到Cursor的定义。 ~~~cpp class HttpHeaderCursor { public: HttpHeaderCursor(const HttpMessage *message); ... void rewind(); ... bool next(std::string& name, std::string& value); bool find(const std::string& name, std::string& value); ... 
}; ~~~ 相信这个cursor在使用上应该不会有什么疑惑。 之后一行resp->get_parsed_body()获得response的http body。这个调用在任务成功的状态下,必然返回true,body指向数据区。 这个调用得到的是原始的http body,不解码chunk编码。如需解码chunk编码,可使用[HttpUtil.h](../src/protocol/HttpUtil.h)里的HttpChunkCursor。 另外需要说明的是,find()接口会修改cursor内部的指针,即使用过find()过后如果仍然想对header进行遍历,需要通过rewind()接口回到cursor头部。 workflow-0.11.8/docs/tutorial-02-redis_cli.md000066400000000000000000000141061476003635400207540ustar00rootroot00000000000000# 实现一次redis写入与读出:redis_cli # 示例代码 [tutorial-02-redis_cli.cc](/tutorial/tutorial-02-redis_cli.cc) # 关于redis_cli 程序从命令行读入一个redis服务器地址,以及以一对key,value。执行SET命令写入这对KV,之后再读出验证写入是否成功。 程序运行方法:./redis_cli \ \ \ 为简单起见,程序需要用Ctrl-C结束。 # Redis URL的格式 redis://:password@host:port/dbnum?query#fragment 如果是SSL,则为: rediss://:password@host:port/dbnum?query#fragment password是可选项。port的缺省值是6379,dbnum缺省值0,范围0-15。 query和fragment部分工厂里不作解释,用户可自行定义。比如,用户有upstream选取需求,可以自定义query和fragment。相关内容参考upstream文档。 redis URL示例: redis://127.0.0.1/ redis://:12345678@redis.some-host.com/1 # 创建并启动Redis任务 创建Redis任务与创建http任务并没有什么区别,少了redirect_max参数。 ~~~cpp using redis_callback_t = std::function; WFRedisTask *create_redis_task(const std::string& url, int retry_max, redis_callback_t callback); ~~~ 在这个示例里,我们想在redis task里存一些用户信息,包括url和key,以便在callback里使用。 当然,我们可利用std::function来绑定参数,但在这里我们利用了task里的void *user_data指针。这是task的一个public成员。 ~~~cpp struct tutorial_task_data { std::sring url; std::string key; }; ... 
struct tutorial_task_data data; data.url = argv[1]; data.key = argv[2]; WFRedisTask *task = WFTaskFactory::create_redis_task(data.url, RETRY_MAX, redis_callback); protocol::RedisRequest *req = task->get_req(); req->set_request("SET", { data.key, argv[3] }); task->user_data = &data; task->start(); pause(); ~~~ 与http task的get_req()类似,redis task的get_req()返回任务对应的redis request。 RedisRequest提供的功能可以在[RedisMessage.h](../src/protocol/RedisMessage.h)查看。 其中,set_request接口用于设置redis命令。 ~~~cpp void set_request(const std::string& command, const std::vector& params); ~~~ 相信经常使用redis的人,对这个接口不会有什么疑问。但必须注意,我们的请求是禁止SELECT命令和AUTH命令的。 因为用户每次请求并不能指定具体连接,SELECT之后下一次请求并不能保证在同一个连接上发起,那么这个命令对用户来讲没有任何意义。 对数据库选择和密码的指定,请在redis URL里完成。并且,必须是每次请求的URL都带着这些信息。 另外,我们的redis client是支持cluster模式的,可以自动处理MOVED和ASK回复并重定向。用户不能自己发送ASKING命令。 # 处理请求结果 程序在SET命令成功之后,再发起一次GET命令,验证写入的结果。GET命令也用同一个callback。所以,函数里会判断这是哪个命令的结果。 同样,我们先忽略错误处理部分。 ~~~cpp void redis_callback(WFRedisTask *task) { protocol::RedisRequest *req = task->get_req(); protocol::RedisResponse *resp = task->get_resp(); int state = task->get_state(); int error = task->get_error(); protocol::RedisValue val; ... resp->get_result(val); std::string cmd; req->get_command(cmd); if (cmd == "SET") { tutorial_task_data *data = (tutorial_task_data *)task->user_data; WFRedisTask *next = WFTaskFactory::create_redis_task(data->url, RETRY_MAX, redis_callback); next->get_req()->set_request("GET", { data->key }); series_of(task)->push_back(next); fprintf(stderr, "Redis SET request success. Trying to GET...\n"); } else /* if (cmd == 'GET') */ { // print the GET result ... fprintf(stderr, "Finished. 
Press Ctrl-C to exit.\n"); } } ~~~ RedisValue是一次redis request得到的结果,同样在[RedisMessage.h](../src/protocol/RedisMessage.h)里可以看到其接口。 callback需要特别解释的,是series_of(task)->push_back(next)这个语句。因为这是我们第一次使用到Workflow的功能。 在这里next是我们下一个要发起的redis task,执行GET操作。我们并不是执行next->start()来启动任务,而是把next任务push_back到当前任务序列的末尾。 这两种方法的区别在于: * 用start来启动任务,任务是被立刻启动的,而push_back的方法,next任务是在callback结束之后被启动。 * 最起码的好处是,push_back方法可以保证log打印不会乱。否则,用next->start()的话,示例中"Finished."这个log可能会被先打印。 * 用start来启动下一个任务的话,当前任务序列(series)就结束了,next任务会新启动一个新的series。 * series是可以设置callback的,虽然在示例中没有用到。 * 在并行任务里,series是并行任务的一个分枝,series结束就会认为分枝结束。并行相关内容在后续教程中讲解。 总之,如果你想在一个任务之后启动下一个任务,一般是使用push_back操作来完成(还有些情况可能要用到push_front)。 而series_of()则是一个非常重要的调用,是一个不属于任何类的全局函数。其定义和实现在[Workflow.h](../src/factory/Workflow.h#L140)里: ~~~cpp static inline SeriesWork *series_of(const SubTask *task) { return (SeriesWork *)task->get_pointer(); } ~~~ 任何task都是SubTask类型的派生。而任何运行中的task,一定属于某个series。通过series_of调用,得到了任务所在的series。 而push_back是SeriesWork类的一个调用,其功能是将一个task放到series的末尾。类似调用还有push_front。本示例中,用哪个调用并没有区别。 ~~~cpp class SeriesWork { ... public: void push_back(SubTask *task); void push_front(SubTask *task); ... } ~~~ SeriesWork类在我们整个体系中,扮演重要的角色。在下一个示例中,我们将展现SeriesWork更多的功能。 workflow-0.11.8/docs/tutorial-03-wget_to_redis.md000066400000000000000000000065031476003635400216600ustar00rootroot00000000000000# 任务序列的更多功能:wget_to_redis # 示例代码 [tutorial-03-wget_to_redis.cc](/tutorial/tutorial-03-wget_to_redis.cc) # 关于wget_to_redis 程序从命令行读入一条http URL和一条redis URL,把抓取下来的Http页面(以http URL为key)存入redis。 与之前两个示例不同,我们加入唤醒机制,让程序可以自动退出,无需Ctrl-C。 # 创建Http任务并设置参数 和上一个示例类似,本示例也是串行执行两个请求。最大的区别是,我们要通知主线程任务已经执行结束,并正常退出。 另外,我们多加入两个调用,限制一下http抓取返回内容的大小,以及接收回复的最大时间。 ~~~cpp WFHttpTask *http_task = WFTaskFactory::create_http_task(...); ... 
http_task->get_resp()->set_size_limit(20 * 1024 * 1024); http_task->set_receive_timeout(30 * 1000); ~~~ set_size_limit()是HttpMessage的调用,用于限制接收http消息时包的大小。事实上所有的协议消息都要求提供这个接口。 set_receive_timeout()是接收数据的超时,单位为ms。 上述代码限制http消息不超过20M,完整接收时间不超过30秒。我们还有更多更丰富的超时配置,后述文档中再介绍。 # 创建并启动SeriesWork 之前两组示例中,我们直接调用task->start()启动第一个任务。task->start()操作实际的工作方法是, 先创建一个以task为首任务的SeriesWork,再启动这个series。在[WFTask.h](../src/factory/WFTask.h)里,可以看到start的实现: ~~~cpp template<class REQ, class RESP> class WFNetworkTask : public CommRequest { public: void start() { assert(!series_of(this)); Workflow::start_series_work(this, nullptr); } ... }; ~~~ 我们想给series设置一个callback,并加入一些上下文。所以我们不使用任务的start接口,而是自己创建一个series。 SeriesWork不能new,delete,也不能派生。只能通过Workflow::create_series_work()接口产生。在[Workflow.h](../src/factory/Workflow.h)中, 通常是用这个调用: ~~~cpp using series_callback_t = std::function<void (const SeriesWork *)>; class Workflow { public: static SeriesWork *create_series_work(SubTask *first, series_callback_t callback); }; ~~~ 在示例代码中,我们的用法是: ~~~cpp struct tutorial_series_context { std::string http_url; std::string redis_url; size_t body_len; bool success; }; ... struct tutorial_series_context context; ...
SeriesWork *series = Workflow::create_series_work(http_task, series_callback); series->set_context(&context); series->start(); ~~~ 之前的示例,我们用task里的void *user_data指针保存上下文信息。但这个示例中,我们把上下文信息放在series里, 这么做显然更合理一些,series是完整的任务链,所有任务都能得到并修改上下文。 series的callback函数在series所有任务被执行完之后调用,在这里,我们简单地用一个lambda函数,打印运行结果并唤醒主线程。 # 其余的工作 剩下的事情就没有什么特别的了,http抓取成功之后启动一个redis任务写库。如果抓取失败或http body长度为0,则不再启动redis任务。 无论是什么情况,程序都能在所有任务结束之后正常退出,因为任务都在同一个series里。 workflow-0.11.8/docs/tutorial-04-http_echo_server.md000066400000000000000000000157001476003635400223650ustar00rootroot00000000000000# 第一个server:http_echo_server # 示例代码 [tutorial-04-http_echo_server.cc](/tutorial/tutorial-04-http_echo_server.cc) # 关于http_echo_server 这是一个http server,返回一个html页面,显示浏览器发送的http请求的header信息。 程序log里会打印出请求的client地址,请求序号(当前连接上的第几次请求)。当同一连接上完成10次请求,server主动关闭连接。 程序通过Ctrl-C正常结束,一切资源完全回收。 # 创建与启动http server 本示例里,我们采用http server的默认参数。创建和启动过程非常简单。 ~~~cpp WFHttpServer server(process); port = atoi(argv[1]); if (server.start(port) == 0) { pause(); server.stop(); } ... ~~~ 这个过程实在太简单,没有什么好讲。要注意start是非阻塞的,所以要pause住程序。显然你也可以启动多个server对象再pause。 server启动之后,任何时刻都可以通过stop()接口关停server。关停是非暴力式的,会等待正在服务的请求执行完。 所以,stop是一个阻塞操作。如果需要非阻塞的关闭,可使用shutdown+wait_finish接口。 start()接口有好几个重载函数,在[WFServer.h](../src/server/WFServer.h)里,可以看到如下一些接口: ~~~cpp class WFServerBase { public: /* To start TCP server.
*/ int start(unsigned short port); int start(int family, unsigned short port); int start(const char *host, unsigned short port); int start(int family, const char *host, unsigned short port); int start(const struct sockaddr *bind_addr, socklen_t addrlen); /* To start an SSL server */ int start(unsigned short port, const char *cert_file, const char *key_file); int start(int family, unsigned short port, const char *cert_file, const char *key_file); int start(const char *host, unsigned short port, const char *cert_file, const char *key_file); int start(int family, const char *host, unsigned short port, const char *cert_file, const char *key_file); int start(const struct sockaddr *bind_addr, socklen_t addrlen, const char *cert_file, const char *key_file); /* For graceful restart or multi-process server. */ int serve(int listen_fd); int serve(int listen_fd, const char *cert_file, const char *key_file); /* Get the listening address. Used when started a server on a random port. */ int get_listen_addr(struct sockaddr *addr, socklen_t *addrlen) const; }; ~~~ 这些接口都比较好理解。任何一个start函数,当端口号为0时,将使用随机端口。此时用户可能需要在server启动完成之后通过get_listen_addr获得实际监听地址。 启动SSL server时,cert_file和key_file为PEM格式。 最后两个带listen_fd的serve()接口,主要用于优雅重启。或者简单建立一个非TCP协议(如SCTP)的server。 需要特别提醒一下,我们一个server对象对应一个listen_fd,如果在IPv4和IPv6两个协议上都运行server,需要: ~~~cpp { WFHttpServer server_v4(process); WFHttpServer server_v6(process); server_v4.start(AF_INET, port); server_v6.start(AF_INET6, port); ... // now stop... 
server_v4.shutdown(); /* shutdown() is nonblocking */ server_v6.shutdown(); server_v4.wait_finish(); server_v6.wait_finish(); } ~~~ 这种方式我们没有办法让两个server共享连接记数。所以推荐只启动IPv6 server,因为IPv6 server可以接受IPv4的连接。 # http echo server的业务逻辑 我们看到在构造http server的时候,传入了一个process参数,这也是一个std::function,定义如下: ~~~cpp using http_process_t = std::function; using WFHttpServer = WFServer; template<> WFHttpServer::WFServer(http_process_t proc) : WFServerBase(&HTTP_SERVER_PARAMS_DEFAULT), process(std::move(proc)) { } ~~~ 其实这个http_proccess_t和的http_callback_t类型是完全一样的。都是处理一个WFHttpTask。 对server来讲,我们的目标就是根据request,填写好response。 同样我们用一个普通函数实现process。逐条读出request的http header写入html页面。 ~~~cpp void process(WFHttpTask *server_task) { protocol::HttpRequest *req = server_task->get_req(); protocol::HttpResponse *resp = server_task->get_resp(); long seq = server_task->get_task_seq(); protocol::HttpHeaderCursor cursor(req); std::string name; std::string value; char buf[8192]; int len; /* Set response message body. */ resp->append_output_body_nocopy("", 6); len = snprintf(buf, 8192, "

<p>%s %s %s</p>

", req->get_method(), req->get_request_uri(), req->get_http_version()); resp->append_output_body(buf, len); while (cursor.next(name, value)) { len = snprintf(buf, 8192, "

<p>%s: %s</p>

", name.c_str(), value.c_str()); resp->append_output_body(buf, len); } resp->append_output_body_nocopy("", 7); /* Set status line if you like. */ resp->set_http_version("HTTP/1.1"); resp->set_status_code("200"); resp->set_reason_phrase("OK"); resp->add_header_pair("Content-Type", "text/html"); resp->add_header_pair("Server", "Sogou WFHttpServer"); if (seq == 9) /* no more than 10 requests on the same connection. */ resp->add_header_pair("Connection", "close"); // print log ... } ~~~ 大多数HttpMessage相关操作之前已经介绍过了,在这里唯一的一个新操作是append_output_body()。 显然让用户生成完整的http body再传给我们并不太高效。用户只需要调用append接口,把离散的数据一块块扩展到message里就可以了。 append_output_body()操作会把数据复制走,另一个带_nocopy后缀的接口会直接引用指针,使用时需要注意不可以指向局部变量。 相关几个调用在[HttpMessage.h](../src/protocol/HttpMessage.h)可以看到其声明: ~~~cpp class HttpMessage { public: bool append_output_body(const void *buf, size_t size); bool append_output_body_nocopy(const void *buf, size_t size); ... bool append_output_body(const std::string& buf); }; ~~~ 再次强调,使用append_output_body_nocopy()接口时,buf指向的数据的生命周期至少需要延续到task的callback。 函数中另外一个变量seq,通过server_task->get_task_seq()得到,表示该请求是当前连接上的第几次请求,从0开始计。 程序中,完成10次请求之后就强行关闭连接,于是: ~~~cpp if (seq == 9) /* no more than 10 requests on the same connection. */ resp->add_header_pair("Connection", "close"); ~~~ 关闭连接还可以通过task->set_keep_alive()接口来完成,但对于http协议,还是推荐使用设置header的方式。 这个示例中,因为返回的页面很小,我们没有关注回复成功与否。下一个示例http_proxy我们将看到如果获得回复的状态。 workflow-0.11.8/docs/tutorial-05-http_proxy.md000066400000000000000000000222441476003635400212440ustar00rootroot00000000000000# 异步server的示例:http_proxy # 示例代码 [tutorial-05-http_proxy.cc](/tutorial/tutorial-05-http_proxy.cc) # 关于http_proxy 这是一个http代理服务器,可以配置在浏览器里使用。支持所有的http method。 因为https代理的原理不同,这个示例并不支持https代理,你只能浏览http网站。 这个proxy在实现上需要抓取下来完整的http页面再转发,下载/上传大文件会有延迟。 # 修改server配置 之前的示例我们使用了默认的http server参数。但这个例子里,我们做一点修改,限制请求的大小,防止被恶意攻击。 ~~~cpp int main(int argc, char *argv[]) { ... 
struct WFServerParams params = HTTP_SERVER_PARAMS_DEFAULT; params.request_size_limit = 8 * 1024 * 1024; WFHttpServer server(&params, process); if (server.start(port) == 0) { pause(); server.stop(); } else { perror("cannot start server"); exit(1); } return 0; } ~~~ 与上一个示例不同,我们在构造server时,多传入一个参数结构。我们可以看看http server有哪些配置。 在[WFHttpServer.h](../src/server/WFHttpServer.h)里,http server的默认参数如下: ~~~cpp static constexpr struct WFServerParams HTTP_SERVER_PARAMS_DEFAULT = { .transport_type = TT_TCP, .max_connections = 2000, .peer_response_timeout = 10 * 1000, .receive_timeout = -1, .keep_alive_timeout = 60 * 1000, .request_size_limit = (size_t)-1, .ssl_accept_timeout = 10 * 1000, }; ~~~ transport_type:传输层协议,默认为TCP。除了TT_TCP外,可选择的还有TT_UDP和Linux下支持的TT_SCTP。 max_connections:最大连接数2000,达到上限之后会关闭最久未使用的keep-alive连接。没找到keep-alive连接,则拒绝新连接。 peer_response_timeout:每读取到一块数据或发送出一块数据的超时时间为10秒。 receive_timeout:接收一条完整的请求超时时间为-1,无限。 keep_alive_timeout:连接保持1分钟。 request_size_limit:请求包最大大小,无限制。 ssl_accept_timeout:完成ssl握手超时,10秒。 参数里没有send_timeout,即完整的回复超时。这个参数需要每次请求根据自己回复包的大小来确定。 # 代理服务器业务逻辑 这个代理服务器本质上是将用户请求原封不动转发到对应的web server,再将web server的回复原封不动转发给用户。 浏览器发给proxy的请求里,request uri包含了scheme和host,port,转发时需要去除。 例如,访问`http://www.sogou.com/`, 浏览器发送给proxy请求首行是: `GET` `http://www.sogou.com/` `HTTP/1.1` 需要改写为: `GET` `/` `HTTP/1.1` ~~~cpp void process(WFHttpTask *proxy_task) { auto *req = proxy_task->get_req(); SeriesWork *series = series_of(proxy_task); WFHttpTask *http_task; /* for requesting remote webserver.
*/ tutorial_series_context *context = new tutorial_series_context; context->url = req->get_request_uri(); context->proxy_task = proxy_task; series->set_context(context); series->set_callback([](const SeriesWork *series) { delete (tutorial_series_context *)series->get_context(); }); http_task = WFTaskFactory::create_http_task(req->get_request_uri(), 0, 0, http_callback); const void *body; size_t len; /* Copy user's request to the new task's reuqest using std::move() */ req->set_request_uri(http_task->get_req()->get_request_uri()); req->get_parsed_body(&body, &len); req->append_output_body_nocopy(body, len); *http_task->get_req() = std::move(*req); /* also, limit the remote webserver response size. */ http_task->get_resp()->set_size_limit(200 * 1024 * 1024); *series << http_task; } ~~~ 以上是process的全部内容。先解析向web server发送的http请求的构造。 req->get_request_uri()调用得到浏览器请求的完整URL,通过这个URL构建发往server的http任务。 这个http任务重试与重定向次数都是0,因为重定向是由浏览器处理,遇到302等会重新发请求。 ~~~cpp req->set_request_uri(http_task->get_req()->get_request_uri()); req->get_parsed_body(&body, &len); req->append_output_body_nocopy(body, len); *http_task->get_req() = std::move(*req); ~~~ 上面4个语句,其实是在生成发往web server的http请求。req是我们收到的http请求,我们最终要通过std::move()把它直接移动到新请求上。 第一行实际上就是将request_uri里的`http://host:port`部分去掉,只保留path之后的部分。 第二第三行把解析下来的http body指定为向外输出的http body。需要做这个操作的原因是,我们的HttpMessage实现里, 解析得到的body和发送请求的body是两个域,所以这里需要简单的设置一下,无需复制内存。 第四行,一次性把请求内容转移给向web server发送的请求。 构造好http请求后,将这个请求放到当前series末尾,process函数结束。 # 异步server的工作原理 显然process函数并不是proxy逻辑的全部,我们还需要处理从web server返回的http response,填写返回给浏览器的response。 在echo server的示例里,我们并不需要进行网络通信,直接填写返回页面就好。但proxy我们需要等待web server的结果。 我们当然可以占用这个process函数的线程,等待结果返回,但这种同步等待的方式明显不是我们想要的。 那么,我们就需要在异步得到请求结果之后,再去回复用户请求,在等待结果期间,不能占用任何的线程。 所以,在process的头部,我们给当前series设置了一个context,context里包含了proxy_task本身,以便我们异步填写结果。 ~~~cpp struct tutorial_series_context { std::string url; WFHttpTask *proxy_task; bool is_keep_alive; }; void process(WFHttpTask *proxy_task) { SeriesWork *series = series_of(proxy_task); ... 
tutorial_series_context *context = new tutorial_series_context; context->url = req->get_request_uri(); context->proxy_task = proxy_task; series->set_context(context); series->set_callback([](const SeriesWork *series) { delete (tutorial_series_context *)series->get_context(); }); ... } ~~~ 之前client的示例中我们说过,任何一个运行中的任务,都处在一个series里,server任务也不例外。 所以,我们可以得到当前series,并设置context。其中url主要是后续打日志之用,proxy_task是主要内容,后续需要填写resp。 接下来我们就可以看看处理web server响应的部分了。 ~~~cpp void http_callback(WFHttpTask *task) { int state = task->get_state(); auto *resp = task->get_resp(); SeriesWork *series = series_of(task); tutorial_series_context *context = (tutorial_series_context *)series->get_context(); auto *proxy_resp = context->proxy_task->get_resp(); ... if (state == WFT_STATE_SUCCESS) { const void *body; size_t len; /* set a callback for getting reply status. */ context->proxy_task->set_callback(reply_callback); /* Copy the remote webserver's response, to proxy response. */ resp->get_parsed_body(&body, &len); resp->append_output_body_nocopy(body, len); *proxy_resp = std::move(*resp); ... } else { // return a "404 Not found" page ... 
} } ~~~ 我们只关注成功的情况。一切可以从web server得到一个完整http页面,不管什么返回码,都是成功。所有失败的情况,简单返回一个404页面。 因为返回给用户的数据可能很大,在我们这个示例里,设置为200MB上限。所以,和之前的示例不同,我们需要查看reply成功/失败状态。 http server任务和我们自行创建的http client任务的类型是完全相同的,都是WFHttpTask。不同的是server任务是框架创建的,它的callback初始为空。 server任务的callback和client一样,是在http交互完成之后被调用。所以,对server任务来讲,就是reply完成之后被调用。 后面三行代码我们应该很熟悉了,无拷贝地将web server响应包转移到proxy响应包。 在这个http_callback函数结束之后,对浏览器的回复被发送出,一切都是在异步的过程中进行。 剩下的一个函数是reply_callback(),在这里只为了打印一些log。在这个callback执行结束后,proxy task会被自动delete。 最后,series的callback里销毁context。 # Server回复的时机 这里需要说明一下,回复消息的时机是在series里所有其它任务被执行完后,自动回复,所以并没有task->reply()接口。 但是,有task->noreply()调用,如果对server任务执行了这个调用,在原本回复的时刻,直接关闭连接。但callback依然会被调用(状态为NOREPLY)。 在server任务的callback里,同样可以通过series_of()操作获得任务的series。那么,我们依然可以往这个series里追加新任务,虽然回复已经完成。 workflow-0.11.8/docs/tutorial-06-parallel_wget.md000066400000000000000000000117061476003635400216500ustar00rootroot00000000000000# 一个简单的并行抓取:parallel_wget # 示例代码 [tutorial-06-parallel_wget.cc](/tutorial/tutorial-06-parallel_wget.cc) # 关于parallel_wget 这是我们第一个并行任务的示例。 程序从命令行读入多个http URL(以空格分割),并行抓取这些URL,并按照输入顺序将抓取结果打印到标准输出。 # 创建并行任务 之前的示例里,我们已经接触过了SeriesWork类。 * SeriesWork由任务构成,代表一系列任务的串行执行。所有任务结束,则这个series结束。 * 与SeriesWork对应的ParallelWork类,parallel由series构成,代表若干个series的并行执行。所有series结束,则这个parallel结束。 * ParallelWork是一种任务。 根据上述的定义,我们就可以动态或静态的生成任意复杂的工作流了。 Workflow类里,有两个接口用于产生并行任务: ~~~cpp class Workflow { ... public: static ParallelWork * create_parallel_work(parallel_callback_t callback); static ParallelWork * create_parallel_work(SeriesWork *const all_series[], size_t n, parallel_callback_t callback); ... }; ~~~ 第一个接口创建一个空的并行任务,第二个接口用一个series数组创建并行任务。 无论用哪个接口产生的并行任务,在启动之前都可以用ParallelWork的add_series()接口添加series。 在示例代码里,我们创建一个空的并行任务,并逐个添加series。 ~~~cpp int main(int argc, char *argv[]) { ParallelWork *pwork = Workflow::create_parallel_work(callback); SeriesWork *series; WFHttpTask *task; HttpRequest *req; tutorial_series_context *ctx; int i; for (i = 1; i < argc; i++) { std::string url(argv[i]); ... 
task = WFTaskFactory::create_http_task(url, REDIRECT_MAX, RETRY_MAX, [](WFHttpTask *task) { // store resp to ctx. }); req = task->get_req(); // add some headers. ... ctx = new tutorial_series_context; ctx->url = std::move(url); series = Workflow::create_series_work(task, nullptr); series->set_context(ctx); pwork->add_series(series); } ... } ~~~ 从代码中看到,我们先创建http任务,但http任务并不能直接加入到并行任务里,需要先用它创建一个series。 每个series都带有context,用于保存url和抓取结果。相关的方法我们在之前的示例里都介绍过。 # 保存和使用抓取结果 http任务的callback是一个简单的lambda函数,把抓取结果保存在自己的series context里,以便并行任务获取。 ~~~cpp task = WFTaskFactory::create_http_task(url, REDIRECT_MAX, RETRY_MAX, [](WFHttpTask *task) { tutorial_series_context *ctx = (tutorial_series_context *)series_of(task)->get_context(); ctx->state = task->get_state(); ctx->error = task->get_error(); ctx->resp = std::move(*task->get_resp()); }); ~~~ 这个做法是必须的,因为http任务在callback之后就会被回收,我们只能把resp通过std::move()操作移走。 而在并行任务的callback里,我们可以很方便的获得结果: ~~~cpp void callback(const ParallelWork *pwork) { tutorial_series_context *ctx; const void *body; size_t size; size_t i; for (i = 0; i < pwork->size(); i++) { ctx = (tutorial_series_context *)pwork->series_at(i)->get_context(); printf("%s\n", ctx->url.c_str()); if (ctx->state == WFT_STATE_SUCCESS) { ctx->resp.get_parsed_body(&body, &size); printf("%zu%s\n", size, ctx->resp.is_chunked() ? " chunked" : ""); fwrite(body, 1, size, stdout); printf("\n"); } else printf("ERROR! 
state = %d, error = %d\n", ctx->state, ctx->error); delete ctx; } } ~~~ 在这里,我们看到ParallelWork的两个新接口,size()和series_at(i),分别获得它的并行series个数,和第i个并行series。 通过series->get_context()取到对应series的上下文,打印结果。打印顺序必然和我们放入顺序一致。 在这个示例中,并行任务执行完就没有其它工作了。 我们上面说过,ParallelWork是一种任务,所以同样我们可以用series_of()获得它所在的series并添加新任务。 但是,如果新任务还要使用到抓取结果,我们需要再次用std::move()把数据移到并行任务所在series的上下文里。 # 并行任务启动 并行任务是一种任务,所以并行任务的启动并没有什么特别,可以直接调用start(),也可以用它建立或启动一个series。 在这个示例里,我们启动一个series,在这个series的callback里唤醒主进程,正常退出程序。 我们也可以在并行任务的callback里唤醒主进程,程序行为上区别不大。但在series callback里唤醒更加规范一点。 workflow-0.11.8/docs/tutorial-07-sort_task.md000066400000000000000000000137361476003635400210450ustar00rootroot00000000000000# 使用内置算法工厂:sort_task # 示例代码 [tutorial-07-sort_task.cc](/tutorial/tutorial-07-sort_task.cc) # 关于sort_task 程序从命令行读入数字n,将随机的n个正整数先升序排列,再把结果再降序排列。 程序可加入第二个参数"p",则可以进行并行排序。例如: $ ./sort_task 100000000 p 上面的命令将先升序排列1亿个整数,再降序排列。两次排序都采用并行。 # 关于计算任务 计算任务(或称线程任务),是我们非常重要的一个功能。在使用我们任务流的时候,并不建议在callback里直接进行非常复杂的计算。 所有需要消耗大量CPU时间的计算,都可以封装成计算任务交给系统去调度。计算任务和通信任务在使用方法上并没有什么区别。 系统的算法工厂提供了一些常用的计算任务,比如排序,归并等。用户也可以很方便定义自己的计算任务。 # 创建升序排序任务 ~~~cpp int main(int argc, char *argv[]) { ... WFSortTask *task; if (use_parallel_sort) task = WFAlgoTaskFactory::create_psort_task("sort", array, end, callback); else task = WFAlgoTaskFactory::create_sort_task("sort", array, end, callback); ... task->start(); ... } ~~~ 和WFHttpTask或WFRedisTask不同,排序任务多了一个模板参数代表要排序的数组数据类型。 create_sort_task和create_psort_task分别产生一个普通排序任务和一个并行排序任务。 这两个调用的参数和返回值并没有区别。 唯一需要特别说明的是第一个参数"sort",这个是计算队列名,用于影响内部的任务调度。本篇文档后面会介绍队列名的用法。 计算任务的启动方法与使用方法和网络通信任务并没有什么区别。 # 处理结果 和通信任务一样,我们在callback里处理结果。这个示例里,升序排序之后会再发起一次降序排序。 ~~~cpp using namespace algorithm; void callback(void SortTask *task) { SortInput *input = task->get_input(); int *first = input->first; int *last = input->last; // print result ... 
if (task->user_data == NULL) { auto cmp = [](int a1, int a2){ return a2 < a1; }; WFSortTask *reverse; if (use_parallel_sort) reverse = WFAlgoTaskFactory::create_psort_task("sort", first, last, cmp, callback); else reverse = WFAlgoTaskFactory::create_sort_task("sort", first, last, cmp, callback); reverse->user_data = (void *)1; /* as a flag */ series_of(task)->push_back(reverse); } else { // all done. Signal main thread to exit. ... } } ~~~ 计算任务的get_input()接口得到输入数据,get_output()得到输出数据。对于排序任务,输入和输出是相同类型,内容也完全相同。 在[WFAlgoTaskFactory.h](../src/factory/WFAlgoTaskFactory.h)里,可以看到排序任务输入输出的定义: ~~~cpp namespace algorithm { template struct SortInput { T *first; T *last; }; template using SortOutput = SortInput; } template using WFSortTask = WFThreadTask, algorithm::SortOutput>; template using sort_callback_t = std::function *)>; ~~~ 显然,input或output里的first, last分别为排序数组的首尾指针。 接下来我们会创建一个降序排序的任务,这时候,我们就需要传进去一个比较函数了。 ~~~cpp auto cmp = [](int a1, int a2)->bool{ return a2 < a1; }; reverse = WFAlgoTaskFactory::create_sort_task("sort", first, last, cmp, callback); ~~~ 可以说我们的用法和std::sort()区别不是很大。但我们的first和last是指针,而不是用iterator。 同样,用create_psort_task()可以创建一个并行排序任务。而对series的使用,和通信任务没有区别。 # 关于计算线程数的配置 如果你不做任何配置,计算调度器将使用当前机器CPU个数的线程数。你也可以通过以下的方式,修改这个值: ~~~cpp #include "workflow/WFGlobal.h" int main() { struct WFGlobalSettings settings = GLOBAL_SETTINGS_DEFAULT; settings.compute_threads = 16; WORKFLOW_library_init(&settings); ... 
} ~~~ 通过上面的配置,我们将创建16个线程用于计算。 # 关于并行排序算法 内置的并行排序算法,使用分块+二路归并。空间复杂度为O(1)。 算法使用全局配置的计算线程进行计算,但最多使用128个线程。因为不使用额外空间,加速比会小于线程数量,平均CPU占用也比较小。 具体实现可参考[WFAlgoTaskFactory.inl](../src/factory/WFAlgoTaskFactory.inl) # 关于计算队列名 我们的计算任务并没有优化级的概念,唯一可以影响调度顺序的是计算任务的队列名,本示例中队列名为字符串"sort"。 队列名的指定非常简单,需要说明以下几点: * 队列名是一个静态字符串,不可以无限产生新的队列名。例如不可以根据请求id来产生队列名,因为内部会为每个队列分配一小块资源。 * 当计算线程没有被100%占满,所有任务都是实时调起,队列名没有任何影响。 * 如果一个服务流程里有多个计算步骤,穿插在多个网络通信之间,可以简单的给每种计算步骤起一个名字,这个会比整体用一个名字要好。 * 如果所有计算任务用同一个名字,那么所有任务的被调度的顺序与提交顺序一致,在某些场景下会影响平均响应时间。 * 每种计算任务有一个独立名字,那么相当于每种任务之间是公平调度的,而同一种任务内部是顺序调度的,实践效果更好。 * 总之,除非机器的计算负载已经非常繁重,否则没有必要特别关心队列名,只要每种任务起一个名字就可以了。 workflow-0.11.8/docs/tutorial-08-matrix_multiply.md000066400000000000000000000221021476003635400222630ustar00rootroot00000000000000# 自定义计算任务:matrix_multiply # 示例代码 [tutorial-08-matrix_multiply.cc](/tutorial/tutorial-08-matrix_multiply.cc) # 关于matrix_multiply 程序执行代码里两个矩阵的乘法,并将相乘结果打印在屏幕上。 示例的主要目的是展现怎么实现一个自定义CPU计算任务。 # 定义计算任务 定义计算任务需要提供3个基本信息,分别为INPUT,OUTPUT,和routine。 INPUT和OUTPUT是两个模板参数,可以是任何类型。routine表示从INPUT到OUTPUT的过程,定义如下: ~~~cpp template class __WFThreadTask { ... std::function routine; ... }; ~~~ 可以看出routine是一个简单的从INPUT到OUTPUT的计算过程。INPUT指针不要求是const,但用户也可以传const INPUT *的函数。 比如一个加法任务,就可这么做: ~~~cpp struct add_input { int x; int y; }; struct add_ouput { int res; }; void add_routine(const add_input *input, add_output *output) { output->res = input->x + input->y; } typedef WFThreadTask add_task; ~~~ 在我们的矩阵乘法的示例里,输入是两个矩阵,输出为一个矩阵。其定义如下: ~~~cpp namespace algorithm { using Matrix = std::vector>; struct MMInput { Matrix a; Matrix b; }; struct MMOutput { int error; size_t m, n, k; Matrix c; }; void matrix_multiply(const MMInput *in, MMOutput *out) { ... 
} } ~~~ 矩阵乘法存在有输入矩阵不合法的问题,所以output里多了一个error域,用来表示错误。 # 生成计算任务 定义好输入输出的类型,以及算法的过程之后,就可以通过WFThreadTaskFactory工厂来产生计算任务了。 在[WFTaskFactory.h](../src/factory/WFTaskFactory.h)里,计算工厂类的定义如下: ~~~cpp template class WFThreadTaskFactory { private: using T = WFThreadTask; public: static T *create_thread_task(const std::string& queue_name, std::function routine, std::function callback); static T *create_thread_task(time_t seconds, long nanoseconds, const std::string& queue_name, std::function routine, std::function callback); ... }; ~~~ 这里包含两个创建任务的接口。第二个接口支持用户传入一下任务运行时间限制,我们在一下节介绍这个功能。 与之前的网络工厂类或算法工厂类略有不同,这个类需要INPUT和OUTPUT两个模板参数。 queue_name相关的知识在上一个示例里已经有介绍。routine就是你的计算过程,callback是回调。 在我们的示例里,我们看到了这个调用的使用: ~~~cpp using MMTask = WFThreadTask; using namespace algorithm; int main() { typedef WFThreadTaskFactory MMFactory; MMTask *task = MMFactory::create_thread_task("matrix_multiply_task", matrix_multiply, callback); MMInput *input = task->get_input(); input->a = {{1, 2, 3}, {4, 5, 6}}; input->b = {{7, 8}, {9, 10}, {11, 12}}; ... 
} ~~~ 产生了task之后,通过get_input()接口得到输入数据的指针。这个可以类比网络任务的get_req()。 任务的发起和结束什么,与网络任务并没有什么区别。同样,回调也很简单: ~~~cpp void callback(MMTask *task) // MMtask = WFThreadTask { MMInput *input = task->get_input(); MMOutput *output = task->get_output(); assert(task->get_state() == WFT_STATE_SUCCESS); if (output->error) printf("Error: %d %s\n", output->error, strerror(output->error)); else { printf("Matrix A\n"); print_matrix(input->a, output->m, output->k); printf("Matrix B\n"); print_matrix(input->b, output->k, output->n); printf("Matrix A * Matrix B =>\n"); print_matrix(output->c, output->m, output->n); } } ~~~ 普通的计算任务可以忽略失败的可能性,结束状态肯定是SUCCESS。 callback里简单打印了输入输出。如果输入数据不合法,则打印错误。 # 带运行时间限制的计算任务 显然,我们的框架无法打断用户的计算任务,因为用户的计算任务是一个函数,用户需要自行确保函数可以正常结束。 但我们支持用户指定一个时间限制,当计算无法在指定时间内完成,任务可以提前回到callback。带运行时间限制的接口定义如下: ~~~cpp template class WFThreadTaskFactory { private: using T = WFThreadTask; public: static T *create_thread_task(time_t seconds, long nanoseconds, const std::string& queue_name, std::function routine, std::function callback); ... 
}; ~~~ 参数seconds和nanoseconds构成了运行时限。在这里,nanoseconds的取值范围在\[0,1000000000)。 当任务无法在运行时限内结束,会直接回到callback,并且任务的状态为WFT_STATE_SYS_ERROR且错误码为ETIMEDOUT。 还是用matrix_multiply的例子,我们可以这样写: ~~~cpp void callback(MMTask *task) // MMtask = WFThreadTask { MMInput *input = task->get_input(); MMOutput *output = task->get_output(); if (task->get_state() == WFT_STATE_SYS_ERROR && task->get_error() == ETIMEDOUT) { printf("Run out of time.\n"); return; } assert(task->get_state() == WFT_STATE_SUCCESS) if (output->error) printf("Error: %d %s\n", output->error, strerror(output->error)); else { printf("Matrix A\n"); print_matrix(input->a, output->m, output->k); printf("Matrix B\n"); print_matrix(input->b, output->k, output->n); printf("Matrix A * Matrix B =>\n"); print_matrix(output->c, output->m, output->n); } } using namespace algorithm; int main() { typedef WFThreadTaskFactory MMFactory; MMTask *task = MMFactory::create_thread_task(0, 1000000, "matrix_multiply_task", matrix_multiply, callback); MMInput *input = task->get_input(); input->a = {{1, 2, 3}, {4, 5, 6}}; input->b = {{7, 8}, {9, 10}, {11, 12}}; ... } ~~~ 上面的示例,限制了任务运行时间不超过1毫秒,否则,以WFT_STATE_SYS_ERROR的状态返回。 再次提醒,我们并不会中断用户的实际运行函数。当任务超时并callback,计算函数还会一直运行直到结束。 如果用户希望函数不再继续执行,需要在代码中自行加入检查点来实现这样的功能。可以在INPUT里加入flag,例如: ~~~cpp void callback(MMTask *task) // MMtask = WFThreadTask { if (task->get_state() == WFT_STATE_SYS_ERROR && task->get_error() == ETIMEDOUT) { task->get_input()->flag = true; printf("Run out of time.\n"); return; } ... } void matrix_multiply(const MMInput *in, MMOutput *out) { while (!in->flag) { .... 
} } ~~~ # 算法与协议的对称性 在我们的体系里,算法与协议在一个非常抽象的层面上是具有高度对称性的。 有自定义算法的线程任务,那显然也存在自定义协议的网络任务。 自定义算法要求提供算法的过程,而自定义协议则需要用户提供序列化和反序列化的过程,[简单的用户自定义协议client/server](./tutorial-10-user_defined_protocol.md)有介绍。 无论是自定义算法还是自定义协议,我们都必须强调算法和协议都是非常纯粹的。 例如算法就是一个从INPUT到OUPUT的转换过程,算法并不知道task,series等的存在。 HTTP协议的实现上,也只关心序列化反序列化,无需要关心什么是task。而是在http task里去引用HTTP协议。 # 线程任务与网络任务的复合性 在这个示例里,我们通过WFThreadTaskFactory构建了一个线程任务。可以说这是一种最简单的计算任务构建,大多数情况下也够用了。 同样,用户可以非常简单的定义一个自有协议的server和client。 但在上一个示例里我们看到,我们可以通过算法工厂产生一个并行排序任务,这显然不是通过一个routine就能做到的。 对于网络任务,比如一个kafka任务,可能要经过与多台机器的交互才能得到结果,但对用户来讲是完全透明的。 所以,我们的任务都是具有复合性的,如果你熟练使用我们的框架,可以设计出很多复杂的组件出来。 workflow-0.11.8/docs/tutorial-09-http_file_server.md000066400000000000000000000176121476003635400223770ustar00rootroot00000000000000# 异步IO的http server:http_file_server # 示例代码 [tutorial-09-http_file_server.cc](/tutorial/tutorial-09-http_file_server.cc) # 关于http_file_server http_file_server是一个web服务器,用户指定启动端口,根路径(默认为程序当路程),就可以启动一个web server。 用户还可以指定一个PEM格式的certificate file和key file,启动一个https web server。 程序在启动server之后,可以从命令行接受用户输入,并通过127.0.0.1地址来访问这个server。 程序主要展示了磁盘IO任务的用法。在Linux系统下,我们利用了Linux底层的aio接口,文件读取完全异步。 # 启动server 启动server这块,和之前的echo server或http proxy没有什么大区别。在这里只是多了一种SSL server的启动方式: ~~~cpp class WFServerBase { ... int start(unsigned short port, const char *cert_file, const char *key_file); ... }; ~~~ 也就是说,start操作可以指定一个PEM格式的cert文件和key文件,启动一个SSL server。 此外,我们在定义server时,用std::bind()给process绑定了一个root参数,代表服务的根路径。 ~~~cpp void process(WFHttpTask *server_task, const char *root) { ... } int main(int argc, char *argv[]) { ... const char *root = (argc >= 3 ? argv[2] : "."); auto&& proc = std::bind(process, std::placeholders::_1, root); WFHttpServer server(proc); // start server ... } ~~~ # 处理请求 与http_proxy类似,我们不占用任何线程读取文件,而是产生一个异步的读文件任务,在读取完成之后回复请求。 再次说明一下,我们需要把完整回复数据读取到内存,才开始回复消息。所以不适合用来传输太大的文件。 ~~~cpp void process(WFHttpTask *server_task, const char *root) { // generate abs path. ... 
int fd = open(abs_path.c_str(), O_RDONLY); if (fd >= 0) { size_t size = lseek(fd, 0, SEEK_END); void *buf = malloc(size); /* As an example, assert(buf != NULL); */ WFFileIOTask *pread_task; pread_task = WFTaskFactory::create_pread_task(fd, buf, size, 0, pread_callback); /* To implement a more complicated server, please use series' context * instead of tasks' user_data to pass/store internal data. */ pread_task->user_data = resp; /* pass resp pointer to pread task. */ server_task->user_data = buf; /* to free() in callback() */ server_task->set_callback([](WFHttpTask *t){ free(t->user_data); }); series_of(server_task)->push_back(pread_task); } else { resp->set_status_code("404"); resp->append_output_body("404 Not Found."); } } ~~~ 与http_proxy产生一个新的http client任务不同,这里我们通过factory产生了一个pread任务。 在[WFTaskFactory.h](../src/factory/WFTaskFactory.h)里,我们可以看到相关的接口。 ~~~cpp struct FileIOArgs { int fd; void *buf; size_t count; off_t offset; }; ... using WFFileIOTask = WFFileTask; using fio_callback_t = std::function; ... class WFTaskFactory { public: ... static WFFileIOTask *create_pread_task(int fd, void *buf, size_t count, off_t offset, fio_callback_t callback); static WFFileIOTask *create_pwrite_task(int fd, void *buf, size_t count, off_t offset, fio_callback_t callback); ... 
/* Interface with file path name */ static WFFileIOTask *create_pread_task(const std::string& path, void *buf, size_t count, off_t offset, fio_callback_t callback); static WFFileIOTask *create_pwrite_task(const std::string& path, void *buf, size_t count, off_t offset, fio_callback_t callback); }; ~~~ 无论是pread还是pwrite,返回的都是WFFileIOTask。这与不区分sort或psort,不区分client或server task是一个道理。 除这两个接口还有preadv和pwritev,返回WFFileVIOTask,以及fsync,fdsync,返回WFFileSyncTask。可以在头文件里查看。 示例用了task的user_data域保存服务的全局数据。但对于大服务,我们推荐使用series context。可以参考前面的[proxy示例](../tutorial/tutorial-05-http_proxy.cc)。 # 处理读文件结果 ~~~cpp using namespace protocol; void pread_callback(WFFileIOTask *task) { FileIOArgs *args = task->get_args(); long ret = task->get_retval(); HttpResponse *resp = (HttpResponse *)task->user_data; /* close fd only when you created File IO task with **fd** interface. */ close(args->fd); if (ret < 0) { resp->set_status_code("503"); resp->append_output_body("503 Internal Server Error."); } else /* Use '_nocopy' carefully. 
*/ resp->append_output_body_nocopy(args->buf, ret); } ~~~ 文件任务的get_args()得到输入参数,这里是FileIOArgs结构,如果是用文件路径名创建的文件任务,其中的fd域等于-1。 get_retval()是操作的返回值。当ret < 0, 任务错误。否则ret为读取到数据的大小。 在文件任务里,ret < 0与task->get_state() != WFT_STATE_SUCCESS完全等价。 buf域的内存我们是自己管理的,可以通过append_output_body_nocopy()传给resp。 在回复完成后,我们会free()这块内存,这个语句在process里: server_task->set_callback([](WFHttpTask *t){ free(t->user_data); }); # 命令行交互 启动server后,用户可以在控制台输入文件名来访问server。当输入文件名为空(Ctrl-D),关闭server并结束程序。 这里,我们使用了WFRepeaterTask来实现这个循环接受输入的过程。WFRepeaterTask是一种循环任务,产生的接口如下: ~~~cpp using repeated_create_t = std::function; using repeater_callback_t = std::function; class WFTaskFactory { WFRpeaterTask *create_repeater_task(repeated_create_t create, repeater_callback_t callback); }; ~~~ 通过create函数,可以创建一个repeater任务。repeater内部会反复调用create,产生一个任务并运行,直到create返回空指针。 在我们的这个示例里,create函数内部调用scanf。当用户输入为空时,create返回NULL,整个循环过程结束。 当用户输入不为空(文件名),产生一个访问127.0.0.1地址的http任务来访问我们开启的server。 ~~~cpp { auto&& create = [&scheme, port](WFRepeaterTask *)->SubTask *{ ... scanf("%1023s", buf); if (*buf == '\0') return NULL; std::string url = scheme + "127.0.0.1:" + std::to_string(port) + "/" + buf; WFHttpTask *task = WFTaskFactory::create_http_task(url, 0, 0, [](WFHttpTask *task) { ... 
}); return task; }; WFFacilities::WaitGroup wg(1); WFRepeaterTask *repeater; repeater = WFTaskFactory::create_repeater_task(create, [&wg](WFRepeaterTask *) { wg.done(); }); repeater->start(); wg.wait(); server.stop(); } ~~~ 最后,当create返回NULL,repeater被callback。我们关闭server并结束程序。 # 关于文件异步IO的实现 Linux操作系统支持一套效率很高,CPU占用非常少的异步IO系统调用。在Linux系统下使用我们的框架将默认使用这套接口。 我们曾经实现过一套posix aio接口用于支持其它UNIX系统,并使用线程的sigevent通知方式,但由于其效率太低,已经不再使用了。 目前,对于非Linux系统,异步IO一律是用多线程实现,在IO任务到达时,实时创建线程执行IO任务,callback回到handler线程池。 多线程IO也是macOS下的唯一选择,因为macOS没有良好的sigevent支持,posix aio行不通。 某些UNIX系统不支持fdatasync调用,这种情况下,fdsync任务将等价于fsync任务。 workflow-0.11.8/docs/tutorial-10-user_defined_protocol.md000066400000000000000000000274601476003635400234020ustar00rootroot00000000000000# 简单的用户自定义协议client/server # 示例代码 [message.h](/tutorial/tutorial-10-user_defined_protocol/message.h) [message.cc](/tutorial/tutorial-10-user_defined_protocol/message.cc) [server.cc](/tutorial/tutorial-10-user_defined_protocol/server.cc) [client.cc](/tutorial/tutorial-10-user_defined_protocol/client.cc) # 关于user_defined_protocol 本示例设计一个简单的通信协议,并在协议上构建server和client。server将client发送的消息转换成大写并返回。 # 协议的格式 协议消息包含一个4字节的head和一个message body。head是一个网络序的整数,指明body的长度。 请求和响应消息的格式一致。 # 协议的实现 用户自定义协议,需要提供协议的序列化和反序列化方法,这两个方法都是ProtocolMessage类的虚函数。 另外,为了使用方便,我们强烈建议用户实现消息的移动构造和移动赋值(用于std::move())。 在[ProtocolMessage.h](../src/protocol/ProtocolMessage.h)里,序列化反序列化接口如下: ~~~cpp namespace protocol { class ProtocolMessage : public CommMessageOut, public CommMessageIn { private: virtual int encode(struct iovec vectors[], int max); /* You have to implement one of the 'append' functions, and the first one * with arguement 'size_t *size' is recommmended. */ virtual int append(const void *buf, size_t *size); virtual int append(const void *buf, size_t size); ... 
}; } ~~~ ### 序列化函数encode * encode函数在消息被发送之前调用,每条消息只调用一次。 * encode函数里,用户需要将消息序列化到一个vector数组,数组元素个数不超过max。目前max的值为2048。 * 结构体struct iovec的定义请参考系统调用readv和writev。 * encode函数正确情况下的返回值在0到max之间,表示消息使用了多少个vector。 * 如果是UDP协议,请注意总长度不超过64k,并且使用不超过1024个vector(Linux一次writev只能1024个vector)。 * encode返回-1表示错误。返回-1时,需要置errno。如果返回值>max,将得到一个EOVERFLOW错误。错误都在callback里得到。 * 为了性能考虑vector里的iov_base指针指向的内容不会被复制。所以一般指向消息类的成员。 ### 反序列化函数append * append函数在每次收到一个数据块时被调用。因此,每条消息可能会调用多次。 * buf和size分别是收到的数据块内容和长度。用户需要把数据内容复制走。 * 如果实现了append(const void \*buf, size_t \*size)接口,可以通过修改\*size来告诉框架本次消费了多少长度。收到的size - 消费的size = 剩余的size,剩余的那部分buf会由下一次append被调起时再次收到。此功能更方便协议解析,当然用户也可以全部复制走自行管理,则无需修改\*size。 * append函数返回0表示消息还不完整,传输继续。返回1表示消息结束。-1表示错误,需要置errno。 * 总之append的作用就是用于告诉框架消息是否已经传输结束。不要在append里做复杂的非必要的协议解析。 ### errno的设置 * encode或append返回-1或其它负数都会被理解为失败,需要通过errno来传递错误原因。用户会在callback里得到这个错误。 * 如果是系统调用或libc等库函数失败(比如malloc),libc肯定会设置好errno,用户无需再设置。 * 一些消息不合法的错误是比较常见的,比如可以用EBADMSG,EMSGSIZE分别表示消息内容错误,和消息太大。 * 用户可以选择超过系统定义errno范围的值来表示一些自定义错误。一般大于256的值是可以用的。 * 请不要使用负数errno。因为框架内部用了负数来代表SSL错误。 在我们的示例里,消息的序列化反序列化都非常的简单。 头文件[message.h](../tutorial/tutorial-10-user_defined_protocol/message.h)里,声明了request和response类: ~~~cpp namespace protocol { class TutorialMessage : public ProtocolMessage { private: virtual int encode(struct iovec vectors[], int max); virtual int append(const void *buf, size_t size); ... 
}; using TutorialRequest = TutorialMessage; using TutorialResponse = TutorialMessage; } ~~~ request和response类,都是同一种类型的消息。直接using就可以。 注意request和response必须可以无参数的被构造,也就是说需要有无参数的构造函数,或完全没有构造函数。 此外,通讯过程中,如果发生重试,response对象会被销毁并重新构造。因此,它最好是一个RAII类。否则处理起来会比较复杂。 [message.cc](../tutorial/tutorial-10-user_defined_protocol/message.cc)里包含了encode和append的实现: ~~~cpp namespace protocol { int TutorialMessage::encode(struct iovec vectors[], int max/*max==8192*/) { uint32_t n = htonl(this->body_size); memcpy(this->head, &n, 4); vectors[0].iov_base = this->head; vectors[0].iov_len = 4; vectors[1].iov_base = this->body; vectors[1].iov_len = this->body_size; return 2; /* return the number of vectors used, no more then max. */ } int TutorialMessage::append(const void *buf, size_t size) { if (this->head_received < 4) { size_t head_left; void *p; p = &this->head[this->head_received]; head_left = 4 - this->head_received; if (size < 4 - this->head_received) { memcpy(p, buf, size); this->head_received += size; return 0; } memcpy(p, buf, head_left); size -= head_left; buf = (const char *)buf + head_left; p = this->head; this->body_size = ntohl(*(uint32_t *)p); if (this->body_size > this->size_limit) { errno = EMSGSIZE; return -1; } this->body = (char *)malloc(this->body_size); if (!this->body) return -1; this->body_received = 0; } size_t body_left = this->body_size - this->body_received; if (size > body_left) { errno = EBADMSG; return -1; } memcpy(this->body, buf, size); if (size < body_left) return 0; return 1; } } ~~~ encode的实现非常简单,固定使用了两个vector,分别指向head和body。需要注意iov_base指针必须指向消息类的成员。 append需要保证4字节的head接收完整,再读取message body。而且我们并不能保证第一次append一定包含完整的head,所以过程略为繁琐。 append实现了size_limit功能,超过size_limit的会返回EMSGSIZE错误。用户如果不需要限制消息大小,可以忽略size_limit这个域。 由于我们要求通信协议是一来一回的,所谓的“TCP黏包”问题不需要考虑,直接当错误消息处理。 现在,有了消息的定义和实现,我们就可以建立server和client了。  # server和client的定义 有了request和response类,我们就可以建立基于这个协议的server和client。前面的示例里我们介绍过Http协议相关的类型定义: ~~~cpp using WFHttpTask = WFNetworkTask; using http_callback_t = 
std::function; using WFHttpServer = WFServer; using http_process_t = std::function; ~~~ 同样的,对这个Tutorial协议,数据类型的定义并没有什么区别: ~~~cpp using WFTutorialTask = WFNetworkTask; using tutorial_callback_t = std::function; using WFTutorialServer = WFServer; using tutorial_process_t = std::function; ~~~ # server端 server与普通的http server没有什么区别。我们优先IPv6启动,这不影响IPv4的client请求。另外限制请求最多不超过4KB。 代码请自行参考[server.cc](../tutorial/tutorial-10-user_defined_protocol/server.cc) # client端 client端的逻辑是从标准IO接收用户输入,构造出请求发往server并得到结果。这里我们使用了WFRepeaterTask来实现这个重复过程,直到用户的输入为空。 此外,为了安全我们限制server回复包不超4KB。 client端唯一需要了解的就是怎么产生一个自定义协议的client任务,在[WFTaskFactory.h](../src/factory/WFTaskFactory.h)有四个接口可以选择: ~~~cpp template class WFNetworkTaskFactory { private: using T = WFNetworkTask; public: static T *create_client_task(TransportType type, const std::string& host, unsigned short port, int retry_max, std::function callback); static T *create_client_task(TransportType type, const std::string& url, int retry_max, std::function callback); static T *create_client_task(TransportType type, const ParsedURI& uri, int retry_max, std::function callback); static T *create_client_task(TransportType type, const struct sockaddr *addr, socklen_t addrlen, int retry_max, std::function callback); ... 
}; ~~~ 其中,TransportType指定传输层协议,目前可选的值包括TT_TCP,TT_UDP,TT_SCTP和TT_TCP_SSL。 四个接口的区别不大,在我们这个示例里暂时不需要URL,我们用域名和端口来创建任务。 如果用户需要使用Unix Domain Protocol访问server,则需要用最后一个接口,直接传入sockaddr。 实际的调用代码如下。我们派生了WFTaskFactory类,但这个派生并非必须的。 ~~~cpp using namespace protocol; class MyFactory : public WFTaskFactory { public: static WFTutorialTask *create_tutorial_task(const std::string& host, unsigned short port, int retry_max, tutorial_callback_t callback) { using NTF = WFNetworkTaskFactory; WFTutorialTask *task = NTF::create_client_task(TT_TCP, host, port, retry_max, std::move(callback)); task->set_keep_alive(30 * 1000); return task; } }; ~~~ 可以看到我们用了WFNetworkTaskFactory类来创建client任务。 接下来通过任务的set_keep_alive()接口,让连接在通信完成之后保持30秒,否则,将默认采用短连接。 client的其它代码涉及的知识点在之前的示例里都包含了。请参考[client.cc](../tutorial/tutorial-10-user_defined_protocol/client.cc) # 内置协议的请求是怎么产生的 现在系统中内置了http, redis,mysql,kafka,dns等协议。我们可以通过相同的方法产生一个http或redis任务吗?比如: ~~~cpp WFHttpTask *task = WFNetworkTaskFactory::create_client_task(...); ~~~ 需要说明的是,这样产生的http任务,会损失很多的功能,比如,无法根据header来识别是否用持久连接,无法识别重定向等。 同样,如果这样产生一个MySQL任务,可能根本就无法运行起来。因为缺乏登录认证过程。 一个kafka请求可能需要和多台broker有复杂的交互过程,这样创建的请求显然也无法完成这一过程。 可见每一种内置协议消息的产生过程都远远比这个示例复杂。同样,如果用户需要实现一个更多功能的通信协议,还有许多代码要写。 workflow-0.11.8/docs/tutorial-11-graph_task.md000066400000000000000000000075561476003635400211550ustar00rootroot00000000000000# 有向无环图(DAG)的使用:graph_task # 示例代码 [tutorial-11-graph_task.cc](/tutorial/tutorial-11-graph_task.cc) # 关于graph_task graph_task示例通过建立一个有向无环图,演示如何用workflow框架实现更加复杂的任务间依赖关系。 # 创建DAG中的任务 DAG中的任务,可以是workflow框架的任何一种任务。在本示例中,我们创建了一个timer任务,两个http任务,以及一个go任务。 Timer执行一秒的等待,http1和http2分别抓取sogou和baidu的首页,go任务打印结果。它们之间的依赖关系如下: ~~~ +-------+ +---->| Http1 |-----+ | +-------+ | +-------+ +-v--+ | Timer | | Go | +-------+ +-^--+ | +-------+ | +---->| Http2 |-----+ +-------+ ~~~ 创建DAG中任务的方法与创建普通任务的方法没有区别,这里不再展开。 # 创建图任务 DAG图在我们的框架里也是一种任务,通过以下代码,我们可以创建一个图任务: ~~~cpp { WFGraphTask *graph = WFTaskFactory::create_graph_task([](WFGraphTask *) { printf("Graph task complete. 
Wakeup main process\n"); wait_group.done(); }); } ~~~ 可以看到,图任务的类型为WFGraphTask,创建函数只有一个参数,即任务的回调。显然一个新建的图任务,是一张空图。 # 创建图节点 接下来,我们需要通过之前创建的4个普通任务(timer,http_task1,http_task2,go_task),产生4个图节点: ~~~cpp { /* Create graph nodes */ WFGraphNode& a = graph->create_graph_node(timer); WFGraphNode& b = graph->create_graph_node(http_task1); WFGraphNode& c = graph->create_graph_node(http_task2); WFGraphNode& d = graph->create_graph_node(go_task); } ~~~ WFGraphTask的create_graph_node接口,产生一个图节点并返回节点的引用,用户通过这个节点引用来建立节点之间的依赖。 如果我们不为节点建立依赖直接运行图任务,那么显然所有节点都是孤立节点,将全部并发执行。 # 建立依赖 通过非常形象的'-->'运算符,我们可以建立节点的依赖关系: ~~~cpp { /* Build the graph */ a-->b; a-->c; b-->d; c-->d; } ~~~ 这样我们就建立起了上述结构的DAG图啦。 除'-->'运算符,我们同样支持'<--'。并且它们都可以连着写。所以,以下程序都是合法且等价的: ~~~cpp { a-->b-->d; a-->c-->d; } ~~~ ~~~cpp { d<--b<--a; d<--c<--a; } ~~~ ~~~cpp { d<--b<--a-->c-->d; } ~~~ 接下来直接运行graph,或者把graph放入任务流中就可以运行啦,和一般的任务没有区别。 当然,把一个图任务变成另一个图的节点,也是完全正确的行为。 # 取消后继节点 在图任务里,我们扩展了series的cancel操作,这个操作会取消该节点的所有后继节点。 取消操作一般在节点任务的callback里执行,例如: ~~~cpp int main() { WFGraphTask *graph = WFTaskFactory::create_graph_task(graph_callback); WFHttpTask *task = WFTaskFactory::create_http_task(url, 0, 0, [](WFHttpTask *t){ if (t->get_state() != WFT_STATE_SUCCESS) series_of(t)->cancel(); }); WFGraphNode& a = graph->create_graph_node(task); WFGraphNode& b = ...; WFGraphNode& c = ...; WFGraphNode& d = ...; a-->b-->c; b-->d; graph->start(); ... 
} ~~~ 注意取消后继节点的操作是递归的,这个例子里,如果http任务失败,b,c,d三个节点的任务都会被取消。 # 数据传递 图节点之间目前没有统一的数据传递方法,它们并不共享某一个series。因此,节点间数据传递需要用户解决。 # 致谢 部分思路来自于[taskflow](https://github.com/taskflow/taskflow)项目。 workflow-0.11.8/docs/tutorial-12-mysql_cli.md000066400000000000000000000366741476003635400210320ustar00rootroot00000000000000# 异步MySQL客户端:mysql_cli # 示例代码 [tutorial-12-mysql_cli.cc](/tutorial/tutorial-12-mysql_cli.cc) # 关于mysql_cli 教程中的mysql_cli使用方式与官方客户端相似,是一个命令行交互式的异步MySQL客户端。 程序运行方式:./mysql_cli \<URL\> 启动之后可以直接在终端输入mysql命令与db进行交互,输入quit或Ctrl-C退出。 # MySQL URL的格式 mysql://username:password@host:port/dbname?character_set=charset&character_set_results=charset - 如果以SSL连接访问MySQL,则scheme设为**mysqls://**。MySQL server 5.7及以上支持; - username和password按需填写,如果密码里包含特殊字符,需要转义后再拼接URL; ~~~cpp // 密码为:@@@@#### std::string url = "mysql://root:" + StringUtil::url_encode_component("@@@@####") + "@127.0.0.1"; ~~~ - port默认为3306; - dbname为要用的数据库名,一般如果SQL语句只操作一个db的话建议填写; - 如果用户在这一层有upstream选取需求,可以参考[upstream文档](/docs/about-upstream.md); - character_set为client的字符集,等价于使用官方客户端启动时的参数``--default-character-set``的配置,默认utf8,具体可以参考MySQL官方文档[character-set.html](https://dev.mysql.com/doc/internals/en/character-set.html)。 - character_set_results为client、connection和results的字符集,如果想要在SQL语句里使用``SET NAMES``来指定这些字符集的话,请把它配置到url的这个位置。 MySQL URL示例: mysql://root:password@127.0.0.1 mysql://@test.mysql.com:3306/db1?character_set=utf8&character_set_results=utf8 mysqls://localhost/db1?character\_set=big5 # 创建并启动MySQL任务 用户可以使用WFTaskFactory创建MySQL任务,创建接口与回调函数的用法都与workflow其他任务类似: ~~~cpp using mysql_callback_t = std::function<void (WFMySQLTask *)>; WFMySQLTask *create_mysql_task(const std::string& url, int retry_max, mysql_callback_t callback); void set_query(const std::string& query); ~~~ 用户创建完WFMySQLTask之后,可以对req调用 **set_query()** 写入SQL语句。 如果没调用过 **set_query()** ,task就被start起来的话,则用户会在callback里得到**WFT_ERR_MYSQL_QUERY_NOT_SET**。 其他包括callback、series、user_data等与workflow其他task用法类似。 大致使用示例如下: ~~~cpp int main(int argc, char *argv[]) { ... 
WFMySQLTask *task = WFTaskFactory::create_mysql_task(url, RETRY_MAX, mysql_callback); task->get_req()->set_query("SHOW TABLES;"); ... task->start(); ... } ~~~ # 支持的命令 目前支持的命令为**COM_QUERY**,已经能涵盖用户基本的增删改查、建库删库、建表删表、预处理、使用存储过程和使用事务的需求。 因为我们的交互命令中不支持选库(**USE**命令),所以,如果SQL语句中有涉及到**跨库**的操作,则可以通过**db_name.table_name**的方式指定具体哪个库的哪张表。 其他所有命令都可以**拼接**到一起通过 ``set_query()`` 传给WFMySQLTask(包括INSERT/UPDATE/SELECT/DELETE/PREPARE/CALL)。 拼接的命令会被按序执行直到命令发生错误,前面的命令会执行成功。 举个例子: ~~~cpp req->set_query("SELECT * FROM table1; CALL procedure1(); INSERT INTO table3 (id) VALUES (1);"); ~~~ # 结果解析 与workflow其他任务类似,可以用``task->get_resp()``拿到**MySQLResponse**,我们可以通过**MySQLResultCursor**遍历结果集。具体接口可以查看:[MySQLResult.h](/src/protocol/MySQLResult.h) 一次请求所对应的回复中,其数据是一个三维结构: - 一个回复中包含了一个或多个结果集(result set); - 一个结果集的类型可能是**MYSQL_STATUS_GET_RESULT**或者**MYSQL_STATUS_OK**; - **MYSQL_STATUS_GET_RESULT**类型的结果集包含了一行或多行(row); - 一行包含了一列或多列,或者说一到多个域(Field/Cell),具体数据结构为**MySQLField**和**MySQLCell**; 结果集的两种类型,可以通过``cursor->get_cursor_status()``进行判断: | |MYSQL_STATUS_GET_RESULT|MYSQL_STATUS_OK| |------|-----------------------|---------------| |SQL命令|SELECT(包括存储过程中的每一个SELECT)|INSERT / UPDATE / DELETE / ...| |对应语义|读操作,一个结果集表示一份读操作返回的二维表|写操作,一个结果集表示一个写操作是否成功| |主要接口|fetch_fields();
fetch_row(&row_arr);
...|get_insert_id();
get_affected_rows();
...| 由于拼接语句可能存在错误,因此这种情况,可以通过**MySQLResultCursor**拿到前面正确执行过的语句多个结果集,以及最后判断``resp->get_packet_type()``为**MYSQL_PACKET_ERROR**时,通过``resp->get_error_code()``和``resp->get_error_msg()``拿到具体错误信息。 一个包含n条**SELECT**语句的**存储过程**,会返回n个**MYSQL_STATUS_GET_RESULT**的结果集和1个**MYSQL_STATUS_OK**的结果集,用户自行忽略此**MYSQL_STATUS_OK**结果集即可。 具体使用从外到内的步骤应该是: 1. 判断任务状态(代表通信层面状态):用户通过判断 ``task->get_state()`` 等于**WFT_STATE_SUCCESS**来查看任务执行是否成功; 2. 判断回复包类型(代表返回包解析状态):调用 **resp->get_packet_type()** 查看最后一条MySQL语句的返回包类型,常见的几个类型为: - MYSQL_PACKET_OK:成功,可以用cursor遍历结果; - MYSQL_PACKET_EOF:成功,可以用cursor遍历结果; - MYSQL_PACKET_ERROR:失败或部分失败,成功的部分可以用cursor遍历结果; 3. 遍历结果集。用户可以使用**MySQLResultCursor**读取结果集中的内容,因为MySQL server返回的数据是多结果集的,因此一开始cursor会**自动指向第一个结果集**的读取位置。 4. 判断结果集状态(代表结果集读取状态):通过 ``cursor->get_cursor_status()`` 可以拿到的几种状态: - MYSQL_STATUS_GET_RESULT:此结果集为读请求类型; - MYSQL_STATUS_END:读结果集已读完最后一行; - MYSQL_STATUS_OK:此结果集为写请求类型; - MYSQL_STATUS_ERROR:解析错误; 5. 读取**MYSQL_STATUS_OK**结果集中的基本内容: - ``unsigned long long get_affected_rows() const;`` - ``unsigned long long get_insert_id() const;`` - ``int get_warnings() const;`` - ``std::string get_info() const;`` 6. 读取**MYSQL_STATUS_GET_RESULT**结果集中的columns中每个field: - ``int get_field_count() const;`` - ``const MySQLField *fetch_field();`` - ``const MySQLField *const *fetch_fields() const;`` 7. 读取**MYSQL_STATUS_GET_RESULT**结果集中的每一行:按行读取可以使用 ``cursor->fetch_row()`` 直到返回值为false。其中会移动cursor内部对当前结果集的指向每行的offset: - ``int get_rows_count() const;`` - ``bool fetch_row(std::vector& row_arr);`` - ``bool fetch_row(std::map& row_map);`` - ``bool fetch_row(std::unordered_map& row_map);`` - ``bool fetch_row_nocopy(const void **data, size_t *len, int *data_type);`` 8. 直接把当前**MYSQL_STATUS_GET_RESULT**结果集的所有行拿出:所有行的读取可以使用 **cursor->fetch_all()** ,内部用来记录行的cursor会直接移动到最后;当前cursor状态会变成**MYSQL_STATUS_END**: - ``bool fetch_all(std::vector>& rows);`` 9. 返回当前**MYSQL_STATUS_GET_RESULT**结果集的头部:如果有必要重读这个结果集,可以使用 **cursor->rewind()** 回到当前结果集头部,再通过第7步或第8步进行读取; 10. 
拿到下一个结果集:因为MySQL server返回的数据包可能是包含多结果集的(比如每个select/insert/...语句为一个结果集;或者call procedure返回的多结果集数据),因此用户可以通过 **cursor->next_result_set()** 跳到下一个结果集,返回值为false表示所有结果集已取完。 11. 返回第一个结果集:**cursor->first_result_set()** 可以让我们返回到所有结果集的头部,然后可以从第4步开始重新拿数据; 12. **MYSQL_STATUS_GET_RESULT**结果集每列具体数据MySQLCell:第7步中读取到的一行,由多列组成,每列结果为MySQLCell,基本使用接口有: - ``int get_data_type();`` 返回MYSQL_TYPE_LONG、MYSQL_TYPE_STRING...具体参考[mysql_types.h](/src/protocol/mysql_types.h) - ``bool is_TYPE() const;`` TYPE为int、string、ulonglong,判断是否是某种类型 - ``TYPE as_TYPE() const;`` 同上,以某种类型读出MySQLCell的数据 - ``void get_cell_nocopy(const void **data, size_t *len, int *data_type) const;`` nocopy接口 整体示例如下: ~~~cpp void task_callback(WFMySQLTask *task) { // step-1. 判断任务状态 if (task->get_state() != WFT_STATE_SUCCESS) { fprintf(stderr, "task error = %d\n", task->get_error()); return; } MySQLResultCursor cursor(task->get_resp()); bool test_first_result_set_flag = false; bool test_rewind_flag = false; // step-2. 判断回复包其他状态 if (resp->get_packet_type() == MYSQL_PACKET_ERROR) { fprintf(stderr, "ERROR. error_code=%d %s\n", task->get_resp()->get_error_code(), task->get_resp()->get_error_msg().c_str()); } begin: // step-3. 遍历结果集 do { // step-4. 判断结果集状态 if (cursor.get_cursor_status() == MYSQL_STATUS_OK) { // step-5. MYSQL_STATUS_OK结果集的基本内容 fprintf(stderr, "OK. %llu rows affected. %d warnings. insert_id=%llu.\n", cursor.get_affected_rows(), cursor.get_warnings(), cursor.get_insert_id()); } else if (cursor.get_cursor_status() == MYSQL_STATUS_GET_RESULT) { fprintf(stderr, "field_count=%u rows_count=%u ", cursor.get_field_count(), cursor.get_rows_count()); // step-6. 读取每个fields。这是个nocopy api const MySQLField *const *fields = cursor.fetch_fields(); for (int i = 0; i < cursor.get_field_count(); i++) { fprintf(stderr, "db=%s table=%s name[%s] type[%s]\n", fields[i]->get_db().c_str(), fields[i]->get_table().c_str(), fields[i]->get_name().c_str(), datatype2str(fields[i]->get_data_type())); } // step-8. 
把所有行读出,也可以while (cursor.fetch_row(map/vector)) 按step-7拿每一行 std::vector> rows; cursor.fetch_all(rows); for (unsigned int j = 0; j < rows.size(); j++) { // step-12. 具体每个cell的读取 for (unsigned int i = 0; i < rows[j].size(); i++) { fprintf(stderr, "[%s][%s]", fields[i]->get_name().c_str(), datatype2str(rows[j][i].get_data_type())); // step-12. 判断具体类型is_string()和转换具体类型as_string() if (rows[j][i].is_string()) { std::string res = rows[j][i].as_string(); fprintf(stderr, "[%s]\n", res.c_str()); } else if (rows[j][i].is_int()) { fprintf(stderr, "[%d]\n", rows[j][i].as_int()); } // else if ... } } } // step-10. 拿下一个结果集 } while (cursor.next_result_set()); if (test_first_result_set_flag == false) { test_first_result_set_flag = true; // step-11. 返回第一个结果集 cursor.first_result_set(); goto begin; } if (test_rewind_flag == false) { test_rewind_flag = true; // step-9. 返回当前结果集头部 cursor.rewind(); goto begin; } return; } ~~~ # WFMySQLConnection 由于我们是高并发异步客户端,这意味着我们对一个server的连接可能会不止一个。而MySQL的事务和预处理都是带状态的,为了保证一次事务或预处理独占一个连接,用户可以使用我们封装的二级工厂WFMySQLConnection来创建任务,每个WFMySQLConnection保证独占一个连接,具体参考[WFMySQLConnection.h](/src/client/WFMySQLConnection.h)。 ### 1. WFMySQLConnection的创建与初始化 创建一个WFMySQLConnection的时候需要传入一个**id**,之后的调用内部都会由这个id和url去找到对应的那个连接。 初始化需要传入**url**,之后在这个connection上创建的任务就不需要再设置url了。 ~~~cpp class WFMySQLConnection { public: WFMySQLConnection(int id); int init(const std::string& url); ... }; ~~~ ### 2. 创建任务与关闭连接 通过 **create_query_task()** ,写入SQL请求和回调函数即可创建任务,该任务一定从这一个connection发出。 有时候我们需要手动关闭这个连接。因为当我们不再使用它的时候,这个连接会一直保持到MySQL server超时。期间如果使用同一个id和url去创建WFMySQLConnection的话就可以复用到这个连接。 因此我们建议如果不准备复用连接,应使用 **create_disconnect_task()** 创建一个任务,手动关闭这个连接。 ~~~cpp class WFMySQLConnection { public: ... 
WFMySQLTask *create_query_task(const std::string& query, mysql_callback_t callback); WFMySQLTask *create_disconnect_task(mysql_callback_t callback); } ~~~ WFMySQLConnection相当于一个二级工厂,我们约定任何工厂对象的生命周期无需保持到任务结束,以下代码完全合法: ~~~cpp WFMySQLConnection *conn = new WFMySQLConnection(1234); conn->init(url); auto *task = conn->create_query_task("SELECT * from table", my_callback); conn->deinit(); delete conn; task->start(); ~~~ ### 3. 注意事项 不可以无限制的产生id来生成连接对象,因为每个id会占用一小块内存,无限产生id会使内存不断增加。当一个连接使用完毕,可以不创建和运行disconnect task,而是让这个连接进入内部连接池。下一个connection通过相同的id和url初始化,会自动复用这个连接。 同一个连接上的多个任务并行启动,会得到EAGAIN错误。 如果在使用事务期间已经开始BEGIN但还没有COMMIT或ROLLBACK,且期间连接发生过中断,则连接会被框架内部自动重连,用户会在下一个task请求中拿到**ECONNRESET**错误。此时还没COMMIT的事务语句已经失效,需要重新再发一遍。 ### 4. 预处理 用户也可以通过WFMySQLConnection来做预处理**PREPARE**,因此用户可以很方便地用作**防SQL注入**。如果连接发生了重连,也会得到一个**ECONNRESET**错误。 ### 5. 完整示例 ~~~cpp WFMySQLConnection conn(1); conn.init("mysql://root@127.0.0.1/test"); // test transaction const char *query = "BEGIN;"; WFMySQLTask *t1 = conn.create_query_task(query, task_callback); query = "SELECT * FROM check_tiny FOR UPDATE;"; WFMySQLTask *t2 = conn.create_query_task(query, task_callback); query = "INSERT INTO check_tiny VALUES (8);"; WFMySQLTask *t3 = conn.create_query_task(query, task_callback); query = "COMMIT;"; WFMySQLTask *t4 = conn.create_query_task(query, task_callback); WFMySQLTask *t5 = conn.create_disconnect_task(task_callback); ((*t1) > t2 > t3 > t4 > t5).start(); ~~~ workflow-0.11.8/docs/tutorial-13-kafka_cli.md000066400000000000000000000264401476003635400207310ustar00rootroot00000000000000# 异步Kafka客户端:kafka_cli # 示例代码 [tutorial-13-kafka_cli.cc](/tutorial/tutorial-13-kafka_cli.cc) # 编译 由于支持Kafka的多种压缩方式,因此系统需要预先安装[zlib](https://github.com/madler/zlib.git),[snappy](https://github.com/google/snappy.git),[lz4(>=1.7.5)](https://github.com/lz4/lz4.git),[zstd](https://github.com/facebook/zstd.git)等第三方库。 支持CMake和Bazel两种编译方式。 CMake:执行命令make KAFKA=y 编译独立的类库(libwfkafka.a和libwfkafka.so)支持kafka协议;cd tutorial; make KAFKA=y 
可以编译kafka_cli Bazel:执行bazel build kafka 编译支持kafka协议的类库;执行bazel build kafka_cli 编译kafka_cli # 关于kafka_cli 这是一个kafka client,可以完成kafka的消息生产(produce)和消息消费(fetch)。 编译时需要在tutorial目录中执行编译命令make KAFKA=y或者在项目根目录执行make KAFKA=y tutorial。 该程序从命令行读取一个kafka broker服务器地址和本次任务的类型(produce/fetch): ./kafka_cli \<broker_url\> [p/c] 程序会在执行完任务后自动退出,一切资源完全回收。 其中broker_url可以由多个url组成,多个url之间以逗号(,)分隔 - 形式如:kafka://host:port,kafka://host1:port... 或:**kafkas**://host:port,**kafkas**://host1:port代表使用SSL通信。 - port的默认值在普通TCP连接下是9092,SSL下为9093。 - "kafka://"前缀可以缺省。这时候默认使用TCP通信。 - 多个url,必须都采用TCP或都采用SSL。否则init函数返回-1,错误码为EINVAL。 - 如果用户在这一层有upstream选取需求,可以参考[upstream文档](../docs/about-upstream.md)。 Kafka broker_url示例: kafka://127.0.0.1/ kafka://kafka.host:9090/ kafka://10.160.23.23:9000,10.123.23.23,kafka://kafka.sogou kafkas://broker1.kafka.sogou,kafkas://broker2.kafka.sogou 错误的url示例(第一个broker为SSL,第二个broker非SSL): kafkas://broker1.kafka.sogou,broker2.kafka.sogou # 实现原理和特性 kafka client内部实现上除了压缩功能外没有依赖第三方库,同时利用了workflow的高性能,在合理的配置和环境下,每秒钟可以处理几万次Kafka请求。 在内部实现上,kafka client会把一次请求按照内部使用到的broker分拆成并行parallel任务,每个broker地址对应parallel任务中的一个子任务, 这样可以最大限度的提升效率,同时利用workflow内部对连接的复用机制使得整体的连接数控制在一个合理的范围。 如果一个broker地址下有多个topic partition,为了提高吞吐,应该创建多个client,然后按照topic partition分别创建任务独立启动。 # 创建并启动Kafka任务 首先需要创建一个WFKafkaClient对象,然后调用init函数初始化WFKafkaClient对象, ~~~cpp int init(const std::string& broker_url); int init(const std::string& broker_url, const std::string& group); ~~~ 其中broker_url是kafka broker集群的地址,格式可以参考上面的broker_url, group是消费者组的group_name,用在基于消费者组的fetch任务中,如果是produce任务或者没有使用消费者组的fetch任务,则不需要使用此接口; 用消费者组的时候,可以设置heartbeat的间隔时间,时间单位是毫秒,用于维持心跳: ~~~cpp void set_heartbeat_interval(size_t interval_ms); ~~~ 后面再通过WFKafkaClient对象创建kafka任务 ~~~cpp using kafka_callback_t = std::function<void (WFKafkaTask *)>; WFKafkaTask *create_kafka_task(const std::string& query, int retry_max, kafka_callback_t cb); WFKafkaTask *create_kafka_task(int retry_max, kafka_callback_t cb); ~~~ 其中query中包含此次任务的类型以及topic等属性,retry_max表示最大重试次数,cb为用户自定义的callback函数,当task执行完毕后会被调用,
接着还可以修改task的默认配置以满足实际需要,详细接口可以在[KafkaDataTypes.h](../src/protocol/KafkaDataTypes.h)中查看 ~~~cpp KafkaConfig config; config.set_client_id("workflow"); task->set_config(std::move(config)); ~~~ 支持的配置选项描述如下: 配置名 | 类型 | 默认值 | 含义 ------ | ---- | -------| ------- produce_timeout | int | 100ms | produce的超时时间 produce_msg_max_bytes | int | 1000000 bytes | 单个消息的最大长度限制 produce_msgset_cnt | int | 10000 | 一次通信消息集合的最大条数 produce_msgset_max_bytes | int | 1000000 bytes | 一次通信消息集合的最大长度限制 fetch_timeout | int | 100ms | fetch的超时时间 fetch_min_bytes | int | 1 byte | 一次fetch通信最小消息的长度 fetch_max_bytes | int | 50M bytes | 一次fetch通信最大消息的长度 fetch_msg_max_bytes | int | 1M bytes | 一次fetch通信单个消息的最大长度 offset_timestamp | long long int | -1 | 消费者组模式下,没有找到历史offset时,初始化的offset,-2表示最久,-1表示最新 session_timeout | int | 10s | 加入消费者组初始化时的超时时间 rebalance_timeout | int | 10s | 加入消费者组同步信息阶段的超时时间 produce_acks | int | -1 | produce任务在返回之前应确保消息成功复制的broker节点数,-1表示所有的复制broker节点 allow_auto_topic_creation | bool | true | produce时topic不存在时,是否自动创建topic broker_version | char * | NULL | 指定broker的版本号,<0.10时需要手动指定 compress_type | int | NoCompress | produce消息的压缩类型 client_id | char * | NULL | 表示client的id check_crcs | bool | false | fetch任务中是否校验消息的crc32 offset_store | int | 0 | 加入消费者组时,是否使用上次提交offset,1表示使用指定的offset,0表示优先使用上次提交 sasl_mechanisms | char * | NULL | sasl认证类型,目前支持plain和scram sasl_username | char * | NULL | sasl认证所需的username sasl_password | char * | NULL | sasl认证所需的password 最后就可以调用start接口启动kafka任务。 # produce任务 1、在创建并初始化WFKafkaClient之后,可以在query中直接指定topic等信息创建WFKafkaTask任务 使用示例如下: ~~~cpp int main(int argc, char *argv[]) { ... client = new WFKafkaClient(); client->init(url); task = client->create_kafka_task("api=produce&topic=xxx&topic=yyy", 3, kafka_callback); ... task->start(); ... 
} ~~~ 2、在创建完WFKafkaTask之后,先通过调用set_key, set_value, add_header_pair等方法构建KafkaRecord, 关于KafkaRecord的更多接口,可以在[KafkaDataTypes.h](../src/protocol/KafkaDataTypes.h)中查看 然后应该通过调用add_produce_record添加KafkaRecord,关于更多接口的详细定义,可以在[WFKafkaClient.h](../src/client/WFKafkaClient.h)中查看 需要注意的是,add_produce_record的第二个参数partition,当>=0是表示指定的partition,-1表示随机指定partition或者调用自定义的kafka_partitioner_t kafka_partitioner_t可以通过set_partitioner接口设置自定义规则。 使用示例如下: ~~~cpp int main(int argc, char *argv[]) { ... WFKafkaClient *client_fetch = new WFKafkaClient(); client_fetch->init(url); task = client_fetch->create_kafka_task("api=produce&topic=xxx&topic=yyy", 3, kafka_callback); task->set_partitioner(partitioner); KafkaRecord record; record.set_key("key1", strlen("key1")); record.set_value(buf, sizeof(buf)); record.add_header_pair("hk1", 3, "hv1", 3); task->add_produce_record("workflow_test1", -1, std::move(record)); ... task->start(); ... } ~~~ 3、produce还可以使用kafka支持的4种压缩协议,通过设置配置项来实现 使用示例如下: ~~~cpp int main(int argc, char *argv[]) { ... WFKafkaClient *client_fetch = new WFKafkaClient(); client_fetch->init(url); task = client_fetch->create_kafka_task("api=produce&topic=xxx&topic=yyy", 3, kafka_callback); KafkaConfig config; config.set_compress_type(Kafka_Zstd); task->set_config(std::move(config)); KafkaRecord record; record.set_key("key1", strlen("key1")); record.set_value(buf, sizeof(buf)); record.add_header_pair("hk1", 3, "hv1", 3); task->add_produce_record("workflow_test1", -1, std::move(record)); ... task->start(); ... } ~~~ # fetch任务 fetch任务支持消费者组模式和手动模式 1、手动模式 无需指定消费者组,同时需要用户指定topic、partition和offset 使用示例如下: ~~~cpp client = new WFKafkaClient(); client->init(url); task = client->create_kafka_task("api=fetch", 3, kafka_callback); KafkaToppar toppar; toppar.set_topic_partition("workflow_test1", 0); toppar.set_offset(0); task->add_toppar(toppar); ~~~ 2、消费者组模式 在初始化client的时候需要指定消费者组的名称 使用示例如下: ~~~cpp int main(int argc, char *argv[]) { ... 
WFKafkaClient *client_fetch = new WFKafkaClient(); client_fetch->init(url, cgroup_name); task = client_fetch->create_kafka_task("api=fetch&topic=xxx&topic=yyy", 3, kafka_callback); ... task->start(); ... } ~~~ 3、offset的提交 在消费者组模式下,用户消费消息后,可以在callback函数中,通过创建commit任务来自动提交消费的记录,使用示例如下: ~~~cpp void kafka_callback(WFKafkaTask *task) { ... commit_task = client.create_kafka_task("api=commit", 3, kafka_callback); ... commit_task->start(); ... } ~~~ # 关于client的关闭 在消费者组模式下,client在关闭之前需要调用create_leavegroup_task创建leavegroup_task, 它会发送leavegroup协议包,如果没有启动leavegroup_task,会导致消费者组没有正确退出,触发这个组的rebalance。 # 处理kafka结果 消息的结果集的数据结构是KafkaResult,可以通过调用WFKafkaTask的get_result()接口获得, 然后调用KafkaResult的fetch_record接口可以将本次task相关的record取出来,它是一个KafkaRecord的二维vector, 第一维是topic partition,第二维是某个topic partition下对应的KafkaRecord, 在[KafkaResult.h](../src/protocol/KafkaResult.h)中可以看到KafkaResult的定义 ~~~cpp void kafka_callback(WFKafkaTask *task) { int state = task->get_state(); int error = task->get_error(); // handle error states ... protocol::KafkaResult *result = task->get_result(); result->fetch_records(records); for (auto &v : records) { for (auto &w: v) { const void *value; size_t value_len; w->get_value(&value, &value_len); printf("produce\ttopic: %s, partition: %d, status: %d, offset: %lld, val_len: %zu\n", w->get_topic(), w->get_partition(), w->get_status(), w->get_offset(), value_len); } } ... protocol::KafkaResult new_result = std::move(*task->get_result()); if (new_result.fetch_records(records)) { for (auto &v : records) { if (v.empty()) continue; for (auto &w: v) { if (fp) { const void *value; size_t value_len; w->get_value(&value, &value_len); fwrite(w->get_value(), w->get_value_len(), 1, fp); } } } } ... } ~~~ # 认证 认证信息需要在配置中设置,以sasl为例: ~~~cpp int main(int argc, char *argv[]) { ... 
client = new WFKafkaClient(); client->init(url); task = client->create_kafka_task("api=fetch&topic=xxx&topic=yyy", 3, kafka_callback); config.set_sasl_username("fetch"); config.set_sasl_password("fetch-secret"); config.set_sasl_mech("SCRAM-SHA-256"); task->set_config(std::move(config)); ... task->start(); ... } ~~~ workflow-0.11.8/docs/tutorial-15-name_service.md000066400000000000000000000075651476003635400214760ustar00rootroot00000000000000# 自定义命名服务策略:name_service # 示例代码 [tutorial-15-name_service.cc](/tutorial/tutorial-15-name_service.cc) # 关于name_service 本示例通过一个用户定义文本文件来指定名称服务策略。文件格式定义与系统hosts文件兼容,目前也支持指向域名。例如: ~~~ 127.0.0.1 www.myhost.com 192.168.10.10 host1 www.sogou.com sogou # 扩展功能,'sogou'指向'www.sogou.com' ~~~ 用户在命令行输入抓取的URL和名称服务文件来抓取目标网页。如果输入URL的域名在文件中不存在,则正常使用DNS。 # 自定义名称服务策略 所有名称服务策略,从WFNSPolicy派生。其唯一需要实现的是create_router_task函数。 ~~~cpp class MyNSPolicy : public WFNSPolicy { public: WFRouterTask *create_router_task(const struct WFNSParams *params, router_callback_t callback) override; .. }; ~~~ 在这个示例里,我们并不需要引入很复杂的选取策略,只需要从一个文本文件把域名做转化。 所以,我们可以把转化结果交给全局的dns resolver,让dns resolver产生真正的路由任务。 ~~~cpp WFRouterTask *MyNSPolicy::create_router_task(const struct WFNSParams *params, router_callback_t callback) { WFDnsResolver *dns_resolver = WFGlobal::get_dns_resolver(); if (params->uri.host) { FILE *fp = fopen(this->path.c_str(), "r"); if (fp) { std::string dest = this->read_from_fp(fp, params->uri.host); if (dest.size() > 0) { /* Update the uri structure's 'host' field directly. * You can also update the 'port' field if needed. */ free(params->uri.host); params->uri.host = strdup(dest.c_str()); } fclose(fp); } } /* Simply, use the global dns resolver to create a router task. 
*/ return dns_resolver->create_router_task(params, std::move(callback)); } ~~~ 其中read_from_fp函数从文本文件中读取信息并做转换,这个函数的实现大家可以直接看源代码。 得到转换结果之后,用新的host覆盖原params里uri的host即可。最后,调用dns resolver产生路由任务。 # 注册名称服务 Workflow里,可以给每个单独的域名指定一个名称服务策略。如果一个域名找不到指定策略,则使用默认。 一般情况下,默认名称服务策略即是dns resolver。下面,我们把定义好的策略注册到输入URL的域名上: ~~~cpp int main() { ... /* Create a naming policy. */ MyNSPolicy *policy = new MyNSPolicy(filename); /* Get the global name service object.*/ WFNameService *ns = WFGlobal::get_name_service(); /* Add our name with the policy to global name service. * You can add multiple names with one policy object. */ ns->add_policy(name, policy); ... } ~~~ 其中,name为URL里的域名。这样的话,这个域名下的所有URL,都将使用我们自定义的名称服务策略了。 在程序退出之前,我们也需要把这个策略从全局名称服务中删除,防止内存泄漏: ~~~cpp int main() { ... /* clean up */ ns->del_policy(name); delete policy; return 0; } ~~~ # 设置默认名称服务策略 在这个例子中,其实我们并没有修改默认名称服务策略。有些情况下,我们可能想让所有的host都使用这个名称服务策略。 这种情况,我们也可以修改默认的策略,让这个策略对所有的host都生效。只需要调用全局名称服务的set_default_policy函数: ~~~cpp int main() { MyNSPolicy *policy = new MyNSPolicy(filename); WFNameService *ns = WFGlobal::get_name_service(); ns->set_default_policy(policy); ... 
/* Reset default policy to dns resolver and clean up */ ns->set_default_policy(WFGlobal::get_dns_resolver()); delete policy; return 0; } ~~~ workflow-0.11.8/docs/tutorial-17-dns_cli.md000066400000000000000000000165461476003635400204520ustar00rootroot00000000000000# 使用workflow请求DNS 作为一款优秀的异步编程框架,workflow帮助用户处理了大量的细节,其中就包括域名解析,因此在大部分情况下,用户无需关心如何请求DNS服务。正如workflow中的其他模块一样,DNS解析模块设计得同样完备而优雅,若恰好需要实现一些域名解析任务,workflow中的WFDnsClient和WFDnsTask无疑是一个绝佳的选择。 [about-dns](about-dns.md)中介绍了如何配置DNS相关参数,而本篇文档的重点在于介绍如何创建DNS任务以及获取解析结果。 [tutorial-17-dns_cli.cc](/tutorial/tutorial-17-dns_cli.cc) ## 使用WFDnsClient创建任务 WFDnsClient是经过封装的高级接口,其行为类似于系统提供的`resolv.conf`配置文件,帮助用户代理了重试、search列表拼接、server轮换等功能,使用起来非常简单。WFDnsClient的初始化方式有以下几种情况,当函数返回0时表示初始化成功 - 使用一个DNS IPv4地址初始化,下述两种写法等价 ```cpp client.init("8.8.8.8"); // or client.init("dns://8.8.8.8/"); ``` - 使用一个DNS IPv6地址初始化 ```cpp client.init("[2402:4e00::]:53"); ``` - 使用DNS over TLS(DoT)地址初始化,默认端口号为853 ```cpp client.init("dnss://120.53.53.53/"); ``` - 使用多个由逗号分隔的DNS地址初始化 ```cpp client.init("dns://8.8.8.8/,119.29.29.29"); ``` - 显式指定重试策略的初始化,示例代码等价于下述`resolv.conf`描述的策略 ``` nameserver 8.8.8.8 search sogou.com tencent.com options ndots:1 attempts:2 rotate ``` ```cpp client.init("8.8.8.8", "sogou.com,tencent.com", 1, 2, true); ``` 使用WFDnsClient创建的任务默认为`DNS_TYPE_A`、`DNS_CLASS_IN`类型的解析请求,且已经设置了递归解析的选项,即`task->get_req()->set_rd(1)`。了解了`WFDnsClient`的初始化的方式,仅需八行即可发起一个DNS解析任务 ```cpp int main() { WFDnsClient client; client.init("8.8.8.8"); WFDnsTask *task = client.create_dns_task("www.sogou.com", dns_callback); task->start(); pause(); client.deinit(); return 0; } ``` ## 使用工厂函数创建任务 若不需要WFDnsClient提供的额外功能,或想自行组织重试策略,可使用工厂函数创建任务。 使用工厂函数创建任务时,可以在`url path`中指定要被解析的域名,工厂函数创建的任务默认为`DNS_TYPE_A`、`DNS_CLASS_IN`类型的解析请求,创建后可以通过`set_question_type`和`set_question_class`修改,例如 ```cpp std::string url = "dns://8.8.8.8/www.sogou.com"; WFDnsTask *task = WFTaskFactory::create_dns_task(url, 0, dns_callback); protocol::DnsRequest *req = task->get_req(); req->set_rd(1); 
req->set_question_type(DNS_TYPE_AAAA); req->set_question_class(DNS_CLASS_IN); ``` 若不在创建任务时指定要被解析的域名(此时默认的任务是对根域名`.`进行解析),在创建任务后可以使用`set_question`函数设置域名等参数,例如 ```cpp std::string url = "dns://8.8.8.8/"; WFDnsTask *task = WFTaskFactory::create_dns_task(url, 0, dns_callback); protocol::DnsRequest *req = task->get_req(); req->set_rd(1); req->set_question("www.zhihu.com", DNS_TYPE_AAAA, DNS_CLASS_IN); ``` ## 借助工具获取结果 一次成功的DNS请求会获得完整的DNS请求结果,有两种简便的接口可以从结果中获取信息 ### DnsUtil::getaddrinfo 该函数类似于系统的`getaddrinfo`函数,调用成功时返回零并成功获得一组`struct addrinfo`,调用失败时返回`EAI_*`类型的错误码。对该函数的成功调用最终**都应该**使用`DnsUtil::freeaddrinfo`释放资源 ```cpp void dns_callback(WFDnsTask *task) { // ignore handle error states struct addrinfo *res; protocol::DnsResponse *resp = task->get_resp(); int ret = protocol::DnsUtil::getaddrinfo(resp, 80, &res); // ignore check ret == 0 char ip_str[INET6_ADDRSTRLEN + 1] = { 0 }; for (struct addrinfo *p = res; p; p = p->ai_next) { void *addr = nullptr; if (p->ai_family == AF_INET) addr = &((struct sockaddr_in *)p->ai_addr)->sin_addr; else if (p->ai_family == AF_INET6) addr = &((struct sockaddr_in6 *)p->ai_addr)->sin6_addr; if (addr) { inet_ntop(p->ai_family, addr, ip_str, p->ai_addrlen); printf("ip:%s\n", ip_str); } } protocol::DnsUtil::freeaddrinfo(res); } ``` ### DnsResultCursor `DnsUtil::getaddrinfo`一般用于获取`IPv4`、`IPv6`地址,而使用DnsResultCursor可以完整地遍历DNS结果。DNS解析结果分为answer、authority、additional三个区域,一般情况下主要内容位于answer区域,此处分别判断每个区域是否有内容,并调用`show_result`以逐一展示结果 ```cpp void dns_callback(WFDnsTask *task) { // ignore handle error states protocol::DnsResponse *resp = task->get_resp(); protocol::DnsResultCursor cursor(resp); if(resp->get_ancount() > 0) { cursor.reset_answer_cursor(); printf(";; ANSWER SECTION:\n"); show_result(cursor); } if(resp->get_nscount() > 0) { cursor.reset_authority_cursor(); printf(";; AUTHORITY SECTION\n"); show_result(cursor); } if(resp->get_arcount() > 0) { cursor.reset_additional_cursor(); printf(";; ADDITIONAL SECTION\n"); show_result(cursor); } } ``` 
根据请求类型不同,结果中包含的数据可以多种多样,常见的有 - DNS_TYPE_A: IPv4类型的地址 - DNS_TYPE_AAAA: IPv6类型的地址 - DNS_TYPE_NS: 该域名的权威DNS服务器 - DNS_TYPE_CNAME: 该域名的权威名称 ```cpp void show_result(protocol::DnsResultCursor &cursor) { char information[1024]; const char *info; struct dns_record *record; struct dns_record_soa *soa; struct dns_record_srv *srv; struct dns_record_mx *mx; while(cursor.next(&record)) { switch (record->type) { case DNS_TYPE_A: info = inet_ntop(AF_INET, record->rdata, information, 64); break; case DNS_TYPE_AAAA: info = inet_ntop(AF_INET6, record->rdata, information, 64); break; case DNS_TYPE_NS: case DNS_TYPE_CNAME: case DNS_TYPE_PTR: info = (const char *)(record->rdata); break; case DNS_TYPE_SOA: soa = (struct dns_record_soa *)(record->rdata); sprintf(information, "%s %s %u %d %d %d %u", soa->mname, soa->rname, soa->serial, soa->refresh, soa->retry, soa->expire, soa->minimum ); info = information; break; case DNS_TYPE_SRV: srv = (struct dns_record_srv *)(record->rdata); sprintf(information, "%u %u %u %s", srv->priority, srv->weight, srv->port, srv->target ); info = information; break; case DNS_TYPE_MX: mx = (struct dns_record_mx *)(record->rdata); sprintf(information, "%d %s", mx->preference, mx->exchange); info = information; break; default: info = "Unknown"; } printf("%s\t%d\t%s\t%s\t%s\n", record->name, record->ttl, dns_class2str(record->rclass), dns_type2str(record->type), info ); } printf("\n"); } ``` workflow-0.11.8/docs/tutorial-18-redis_subscriber.md000066400000000000000000000110401476003635400223510ustar00rootroot00000000000000# Redis订阅模式 ## 示例代码 [tutorial-18-redis_subscriber.cc](/tutorial/tutorial-18-redis_subscriber.cc) ## 创建订阅客户端和任务 在Workflow中,一个客户端网络任务通常是向服务端发出一个请求并接收一个回复,而Redis订阅任务不同,它会先发出一个订阅请求,然后源源不断地接收服务端推送过来的消息,在这个过程中,客户端还可以新增或取消channels、patterns。 用于实现Redis订阅功能的任务是`WFRedisSubscribeTask`,与普通的Redis任务不同,它不从任务工厂产生,而是需要使用`WFRedisSubscriber`来创建。例如 ```cpp WFRedisSubscriber suber; if (suber.init(url) != 0) { std::cerr << "Subscriber init failed " << strerror(errno) << 
std::endl; exit(1); } // ... WFRedisSubscribeTask *task; task = suber.create_subscribe_task(channels, extract, callback); task->set_watch_timeout(1000000); // 1000秒 task->start(); // 这里可以使用task的相关接口改变订阅内容 // ... task->release(); suber.deinit(); ``` 初始化`WFRedisSubscriber`需要使用`Redis URL`,这与普通Redis任务相同,不再赘述。创建订阅任务时,需要提供三个参数 - channels/patterns: 一个或多个被订阅的channel(subscribe)或pattern(psubscribe) - extract: 收到服务端推送消息时的处理函数 - callback: 任务结束后的回调函数 这个例子中为`watch_timeout`设置了一个很长的时间,若这个时间较短,且服务端长时间未推送消息,则连接会因为超时而断开,订阅任务也会直接失败,请根据实际情况合理设置。 当任务处理完成后,需要通过`task->release()`来释放这个任务,这也是与其他任务的一个不同之处。 ## 处理订阅消息 服务端推送的消息由创建任务时指定的`extract`函数处理。后续描述中,subscribe对应channel,psubscribe对应pattern。 1. 服务端推送的消息格式是具有三个元素的数组,第一个元素是字符串"message"或"pmessage",第二个元素是该消息的channel或pattern的名称,第三个元素是消息的内容。 2. subscribe或psubscribe请求的回复是具有三个元素的数组,第一个元素是字符串"subscribe"或"psubscribe",第二个元素是channel或pattern的名称,第三个元素是当前通过subscribe或psubscribe命令已经订阅了多少个channel或pattern,是一个整数。如果一个请求订阅了多个channel或pattern,会有多个回复。 3. unsubscribe或punsubscribe请求的回复是具有三个元素的数组,格式与订阅命令相似。当取消订阅但不指定channel或pattern时,表示取消所有该类型的订阅,对于所有已经订阅的channel或pattern,返回一个回复消息。若当前类型未订阅任何channel或pattern,则返回一个消息,其中名称部分为nil。 更多详情可参阅redis文档。 处理消息的一个示例如下,简单地将内容打印到标准输出 ```cpp void extract(WFRedisSubscribeTask *task) { auto *resp = task->get_resp(); protocol::RedisValue value; resp->get_result(value); if (value.is_array()) { for (size_t i = 0; i < value.arr_size(); i++) { if (value[i].is_string()) std::cout << value[i].string_value(); else if (value[i].is_int()) std::cout << value[i].int_value(); else if (value[i].is_nil()) std::cout << "nil"; else std::cout << "Unexpected value in array!"; std::cout << "\n"; } } else std::cout << "Unexpected value!\n"; } ``` ## 改变订阅内容 在任务过程中,可以通过下述接口新增或取消订阅,注意在带有channels或patterns参数的接口中,请勿传入空数组。 ```cpp // ... 
task->start(); // 新增订阅一组channels task->subscribe(channels); // 取消订阅一组channels task->unsubscribe(channels); // 取消订阅所有channels task->unsubscribe(); // 新增订阅一组patterns task->psubscribe(patterns); // 取消订阅一组patterns task->punsubscribe(patterns); // 取消订阅所有patterns task->punsubscribe(); task->release(); ``` 当所有channels和patterns都被取消订阅后,任务会直接结束,此后不能再新增订阅,请注意该细节。也可以直接通过`task->quit()`来主动结束任务。 此外,订阅模式下可以通过`task->ping()`或`task->ping(message)`向Redis服务器发起`ping`请求。当任务设置了较小的`watch_timeout`,但服务端可能长时间没有消息推送时,通过定时发出`ping`请求可以令服务端推送`pong`响应,此时任务便不会因为超时而失败。 workflow-0.11.8/docs/tutorial-19-dns_server.md000066400000000000000000000056131476003635400212040ustar00rootroot00000000000000# 使用workflow实现DNS服务器 前述文档已经讲解了使用workflow实现服务器的方法,workflow框架贴心地为用户处理了底层逻辑和各种细节,因此本文档主要介绍如何组装DNS消息。 [tutorial-19-dns_server.cc](/tutorial/tutorial-19-dns_server.cc) DNS协议内容中包含三个section,有`DNS_ANSWER_SECTION`、`DNS_AUTHORITY_SECTION`、`DNS_ADDITIONAL_SECTION`,每个section中可包含零或多条资源记录`Resource record`。目前`protocol::DnsResponse`支持添加的资源记录类型有`DNS_TYPE_A`、`DNS_TYPE_AAAA`、`DNS_TYPE_CNAME`、`DNS_TYPE_PTR`、`DNS_TYPE_SOA`、`DNS_TYPE_SRV`、`DNS_TYPE_MX`,其接口如下所示。 ```cpp int add_a_record(int section, const char *name, uint16_t rclass, uint32_t ttl, const void *data); int add_aaaa_record(int section, const char *name, uint16_t rclass, uint32_t ttl, const void *data); int add_ns_record(int section, const char *name, uint16_t rclass, uint32_t ttl, const char *data); int add_cname_record(int section, const char *name, uint16_t rclass, uint32_t ttl, const char *data); int add_ptr_record(int section, const char *name, uint16_t rclass, uint32_t ttl, const char *data); int add_soa_record(int section, const char *name, uint16_t rclass, uint32_t ttl, const char *mname, const char *rname, uint32_t serial, int32_t refresh, int32_t retry, int32_t expire, uint32_t minimum); int add_srv_record(int section, const char *name, uint16_t rclass, uint32_t ttl, uint16_t priority, uint16_t weight, uint16_t port, const char *target); int add_mx_record(int 
section, const char *name, uint16_t rclass, uint32_t ttl, int16_t preference, const char *exchange); int add_raw_record(int section, const char *name, uint16_t type, uint16_t rclass, uint32_t ttl, const void *data, uint16_t dlen); ``` 例如要添加一条AAAA记录,可使用下述方式实现 ```cpp struct in6_addr addr; inet_pton(AF_INET6, "1234:5678:9abc:def0::", (void *)&addr); resp->add_aaaa_record(DNS_ANSWER_SECTION, name.c_str(), DNS_CLASS_IN, 600, &addr); ``` 对于未支持的资源记录类型,可通过`add_raw_record`接口添加,例如要添加一条TXT记录,可使用下述方式实现 ```cpp const char *raw_txt_data = "\x0dmy dns server\x0fyour dns server"; uint16_t data_len = 30; resp->add_raw_record(DNS_ANSWER_SECTION, name.c_str(), DNS_TYPE_TXT, DNS_CLASS_IN, 1200, raw_txt_data, data_len); ``` 注意,默认情况下`WFDnsServer`会启动一个UDP服务,若需要启动TCP服务,可通过修改WFServerParams中的transport_type字段为`TT_TCP`来实现。DNS客户端通常会优先使用UDP协议发起请求,当要回复的消息过大时,可仅添加部分资源记录,并通过`resp->set_tc(1)`设置截断标记,指示客户端可使用TCP协议重新请求。 workflow-0.11.8/docs/xmake.md000066400000000000000000000025171476003635400160470ustar00rootroot00000000000000# 编译 ``` // 编译workflow库 xmake // 编译test xmake -g test // 运行test文件 xmake run -g test // 编译tutorial xmake -g tutorial // 编译benchmark xmake -g benchmark ``` ## 运行 `xmake run -h` 可以查看运行哪些target 选择一个target即可运行 比如 ``` xmake run tutorial-06-parallel_wget ``` ## 安装 ``` sudo xmake install ``` ## 切换编译静态库/动态库 ``` // 编译静态库 xmake f -k static xmake -r ``` ``` // 编译动态库 xmake f -k shared xmake -r ``` `tips : -r 代表 -rebuild` ## 进行定制化裁剪 `xmake f --help` 可查看我们定制的option ``` Command options (Project Configuration): --workflow_inc=WORKFLOW_INC workflow inc (default: /media/psf/pro/workflow/_include) --upstream=[y|n] build upstream component (default: y) --consul=[y|n] build consul component --workflow_lib=WORKFLOW_LIB workflow lib (default: /media/psf/pro/workflow/_lib) --redis=[y|n] build redis component (default: y) --kafka=[y|n] build kafka component --mysql=[y|n] build mysql component (default: y) ``` 你可以通过如下命令来进行各个模块的裁剪或集成 ``` xmake f --redis=n --kafka=y --mysql=n xmake -r ``` 
workflow-0.11.8/src/000077500000000000000000000000001476003635400142525ustar00rootroot00000000000000workflow-0.11.8/src/CMakeLists.txt000066400000000000000000000103341476003635400170130ustar00rootroot00000000000000cmake_minimum_required(VERSION 3.6) if(ANDROID) include_directories(${OPENSSL_INCLUDE_DIR}) link_directories(${OPENSSL_LINK_DIR}) else() find_package(OpenSSL REQUIRED) endif () include_directories(${OPENSSL_INCLUDE_DIR} ${INC_DIR}/workflow) if (KAFKA STREQUAL "y") find_path(SNAPPY_INCLUDE_PATH NAMES snappy.h) find_library(SNAPPY_LIB NAMES snappy) if ((NOT SNAPPY_INCLUDE_PATH) OR (NOT SNAPPY_LIB)) message(FATAL_ERROR "Fail to find snappy with KAFKA=y") endif () include_directories(${SNAPPY_INCLUDE_PATH}) endif () add_subdirectory(kernel) add_subdirectory(util) add_subdirectory(manager) add_subdirectory(protocol) add_subdirectory(factory) add_subdirectory(nameservice) add_subdirectory(server) add_subdirectory(client) add_dependencies(kernel LINK_HEADERS) add_dependencies(util LINK_HEADERS) add_dependencies(manager LINK_HEADERS) add_dependencies(protocol LINK_HEADERS) add_dependencies(factory LINK_HEADERS) add_dependencies(nameservice LINK_HEADERS) add_dependencies(server LINK_HEADERS) add_dependencies(client LINK_HEADERS) set(STATIC_LIB_NAME ${PROJECT_NAME}-static) set(SHARED_LIB_NAME ${PROJECT_NAME}-shared) add_library( ${STATIC_LIB_NAME} STATIC $ $ $ $ $ $ $ $ ) add_library( ${SHARED_LIB_NAME} SHARED $ $ $ $ $ $ $ $ ) if(ANDROID) target_link_libraries(${SHARED_LIB_NAME} ssl crypto c) target_link_libraries(${STATIC_LIB_NAME} ssl crypto c) else() target_link_libraries(${SHARED_LIB_NAME} OpenSSL::SSL OpenSSL::Crypto pthread) target_link_libraries(${STATIC_LIB_NAME} OpenSSL::SSL OpenSSL::Crypto pthread) endif () set_target_properties(${STATIC_LIB_NAME} PROPERTIES OUTPUT_NAME ${PROJECT_NAME}) set_target_properties(${SHARED_LIB_NAME} PROPERTIES OUTPUT_NAME ${PROJECT_NAME} VERSION ${PROJECT_VERSION} SOVERSION ${PROJECT_VERSION_MAJOR}) if (KAFKA STREQUAL "y") 
add_dependencies(client_kafka LINK_HEADERS) add_dependencies(util_kafka LINK_HEADERS) add_dependencies(protocol_kafka LINK_HEADERS) add_dependencies(factory_kafka LINK_HEADERS) set(KAFKA_STATIC_LIB_NAME "wfkafka-static") add_library( ${KAFKA_STATIC_LIB_NAME} STATIC $ $ $ $ ) set_target_properties(${KAFKA_STATIC_LIB_NAME} PROPERTIES OUTPUT_NAME "wfkafka") set(KAFKA_SHARED_LIB_NAME "wfkafka-shared") add_library( ${KAFKA_SHARED_LIB_NAME} SHARED $ $ $ $ ) if (APPLE) target_link_libraries(${KAFKA_SHARED_LIB_NAME} ${SHARED_LIB_NAME} z lz4 zstd ${SNAPPY_LIB}) else () target_link_libraries(${KAFKA_SHARED_LIB_NAME} ${SHARED_LIB_NAME}) endif () set_target_properties(${KAFKA_SHARED_LIB_NAME} PROPERTIES OUTPUT_NAME "wfkafka" VERSION ${PROJECT_VERSION} SOVERSION ${PROJECT_VERSION_MAJOR}) endif () install( TARGETS ${STATIC_LIB_NAME} ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} COMPONENT devel ) install( TARGETS ${SHARED_LIB_NAME} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} COMPONENT devel ) if (KAFKA STREQUAL "y") install( TARGETS ${KAFKA_STATIC_LIB_NAME} ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} COMPONENT devel ) install( TARGETS ${KAFKA_SHARED_LIB_NAME} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} COMPONENT devel ) endif () target_include_directories(${STATIC_LIB_NAME} BEFORE PUBLIC "$" "$") target_include_directories(${SHARED_LIB_NAME} BEFORE PUBLIC "$" "$") if (KAFKA STREQUAL "y") target_include_directories(${KAFKA_STATIC_LIB_NAME} BEFORE PUBLIC "$" "$") target_include_directories(${KAFKA_SHARED_LIB_NAME} BEFORE PUBLIC "$" "$") endif () workflow-0.11.8/src/client/000077500000000000000000000000001476003635400155305ustar00rootroot00000000000000workflow-0.11.8/src/client/CMakeLists.txt000066400000000000000000000007341476003635400202740ustar00rootroot00000000000000cmake_minimum_required(VERSION 3.6) project(client) set(SRC WFDnsClient.cc ) if (NOT REDIS STREQUAL "n") set(SRC ${SRC} WFRedisSubscriber.cc ) endif () if (NOT MYSQL STREQUAL "n") set(SRC ${SRC} WFMySQLConnection.cc ) 
endif () if (NOT CONSUL STREQUAL "n") set(SRC ${SRC} WFConsulClient.cc ) endif () add_library(${PROJECT_NAME} OBJECT ${SRC}) if (KAFKA STREQUAL "y") set(SRC WFKafkaClient.cc ) add_library("client_kafka" OBJECT ${SRC}) endif () workflow-0.11.8/src/client/WFConsulClient.cc000066400000000000000000000670331476003635400207070ustar00rootroot00000000000000/* Copyright (c) 2022 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Authors: Wang Zhenpeng (wangzhenpeng@sogou-inc.com) */ #include #include #include #include #include #include #include "json_parser.h" #include "StringUtil.h" #include "URIParser.h" #include "HttpUtil.h" #include "WFConsulClient.h" using namespace protocol; WFConsulTask::WFConsulTask(const std::string& proxy_url, const std::string& service_namespace, const std::string& service_name, const std::string& service_id, int retry_max, consul_callback_t&& cb) : proxy_url(proxy_url), callback(std::move(cb)) { this->service.service_name = service_name; this->service.service_namespace = service_namespace; this->service.service_id = service_id; this->api_type = CONSUL_API_TYPE_UNKNOWN; this->retry_max = retry_max; this->finish = false; this->consul_index = 0; } void WFConsulTask::set_service(const struct protocol::ConsulService *service) { this->service.tags = service->tags; this->service.meta = service->meta; this->service.tag_override = service->tag_override; this->service.service_address = service->service_address; this->service.lan = service->lan; this->service.lan_ipv4 = 
service->lan_ipv4;
	this->service.lan_ipv6 = service->lan_ipv6;
	this->service.virtual_address = service->virtual_address;
	this->service.wan = service->wan;
	this->service.wan_ipv4 = service->wan_ipv4;
	this->service.wan_ipv6 = service->wan_ipv6;
}

/*
 * Forward declarations of the file-local JSON result parsers used by the
 * two result getters below.
 * NOTE(review): the element type of std::vector appears to be missing in
 * this copy of the file -- confirm against the upstream source.
 */
static bool parse_discover_result(const json_value_t *root,
	std::vector& result);

static bool parse_list_service_result(const json_value_t *root,
	std::vector& result);

/*
 * Decode the stored HTTP response body (chunked decoding) and parse it as
 * JSON with parse_discover_result(), filling 'result'.
 *
 * Returns false with errno = EPERM when this task's api_type is not
 * CONSUL_API_TYPE_DISCOVER. Otherwise errno is preset to EBADMSG before the
 * body is decoded and parsed; the value errno had on entry is restored only
 * when parsing succeeds.
 */
bool WFConsulTask::get_discover_result(
	std::vector& result)
{
	json_value_t *root;
	int errno_bak;
	bool ret;

	if (this->api_type != CONSUL_API_TYPE_DISCOVER)
	{
		errno = EPERM;
		return false;
	}

	/* Save errno so a successful parse leaves the caller's errno intact. */
	errno_bak = errno;
	errno = EBADMSG;
	std::string body = HttpUtil::decode_chunked_body(&this->http_resp);
	root = json_value_parse(body.c_str());
	if (!root)
		return false;

	ret = parse_discover_result(root, result);
	json_value_destroy(root);
	if (ret)
		errno = errno_bak;

	return ret;
}

/*
 * Same contract as get_discover_result(), but requires api_type to be
 * CONSUL_API_TYPE_LIST_SERVICE and parses the body with
 * parse_list_service_result().
 */
bool WFConsulTask::get_list_service_result(
	std::vector& result)
{
	json_value_t *root;
	int errno_bak;
	bool ret;

	if (this->api_type != CONSUL_API_TYPE_LIST_SERVICE)
	{
		errno = EPERM;
		return false;
	}

	/* Save errno so a successful parse leaves the caller's errno intact. */
	errno_bak = errno;
	errno = EBADMSG;
	std::string body = HttpUtil::decode_chunked_body(&this->http_resp);
	root = json_value_parse(body.c_str());
	if (!root)
		return false;

	ret = parse_list_service_result(root, result);
	json_value_destroy(root);
	if (ret)
		errno = errno_bak;

	return ret;
}

/*
 * Entry point of the task. When 'finish' is already set, the task just
 * reports completion via subtask_done(). Otherwise an HTTP sub-task matching
 * api_type is created.
 */
void WFConsulTask::dispatch()
{
	WFHttpTask *task;

	if (this->finish)
	{
		this->subtask_done();
		return;
	}

	switch(this->api_type)
	{
	case CONSUL_API_TYPE_DISCOVER:
		task = create_discover_task();
		break;
	case CONSUL_API_TYPE_LIST_SERVICE:
		task = create_list_service_task();
		break;
	case CONSUL_API_TYPE_DEREGISTER:
		task = create_deregister_task();
		break;
	case CONSUL_API_TYPE_REGISTER:
		task = create_register_task();
		if (task)
			break;

		/*
		 * Two failure paths share the tail below: a failed register-task
		 * creation falls into the 'if (1)' branch (system error, errno),
		 * while an unknown api_type jumps to the 'default' label placed
		 * inside the 'else' branch (task error).
		 */
		if (1)
		{
			this->state = WFT_STATE_SYS_ERROR;
			this->error = errno;
		}
		else
		{
	default:
			this->state = WFT_STATE_TASK_ERROR;
			this->error = WFT_ERR_CONSUL_API_UNKNOWN;
		}

		this->finish = true;
		this->subtask_done();
		return;
}

	/*
	 * Push this task and then the HTTP sub-task to the front of the series,
	 * so the HTTP task ends up ahead of this task -- presumably it runs
	 * first and this task is dispatched again afterwards; confirm against
	 * SeriesWork::push_front().
	 */
	series_of(this)->push_front(this);
	series_of(this)->push_front(task);
	this->subtask_done();
}

/*
 * Completion hook. Once 'finish' is set, invokes the user callback (if any)
 * and deletes this task. The series pointer is read before the possible
 * 'delete this', and the next sub-task popped from it is returned.
 */
SubTask *WFConsulTask::done()
{
	SeriesWork *series = series_of(this);

	if (finish)
	{
		if (this->callback)
			this->callback(this);

		delete this;
	}

	return series->pop();
}

/*
 * Format a duration given in milliseconds as a string with a unit suffix:
 * whole minutes ("<n>m") when it is at least 180 seconds, otherwise whole
 * seconds ("<n>s"). Fractions are truncated by integer division.
 */
static std::string convert_time_to_str(int milliseconds)
{
	std::string str_time;
	int seconds = milliseconds / 1000;

	if (seconds >= 180)
		str_time = std::to_string(seconds / 60) + "m";
	else
		str_time = std::to_string(seconds) + "s";

	return str_time;
}

/*
 * Build the URL of the health/service discovery request from proxy_url, the
 * service name/namespace and the task config (dc, passing, token, filter).
 * NOTE(review): parameter values are concatenated as-is, with no URL
 * encoding applied in this function.
 */
std::string WFConsulTask::generate_discover_request()
{
	std::string url = this->proxy_url;

	url += "/v1/health/service/" + this->service.service_name;
	url += "?dc=" + this->config.get_datacenter();
	url += "&ns=" + this->service.service_namespace;
	std::string passing = this->config.get_passing() ? "true" : "false";
	url += "&passing=" + passing;
	url += "&token=" + this->config.get_token();
	url += "&filter=" + this->config.get_filter_expr();

	// Consul blocking query: send the last known index and the wait time.
	if (this->config.blocking_query())
	{
		url += "&index=" + std::to_string(this->get_consul_index());
		url += "&wait=" + convert_time_to_str(this->config.get_wait_ttl());
	}

	return url;
}

/*
 * Create the HTTP task for a discover request. The WFConsulTask is stored in
 * user_data so discover_callback can find it.
 */
WFHttpTask *WFConsulTask::create_discover_task()
{
	std::string url = generate_discover_request();
	WFHttpTask *task = WFTaskFactory::create_http_task(url, 0, this->retry_max,
													   discover_callback);
	HttpRequest *req = task->get_req();

	req->add_header_pair("Content-Type", "application/json");
	task->user_data = this;
	return task;
}

/*
 * Create the HTTP task that lists catalog services, passing token, dc and
 * namespace as query parameters. The WFConsulTask is stored in user_data.
 */
WFHttpTask *WFConsulTask::create_list_service_task()
{
	std::string url = this->proxy_url;

	url += "/v1/catalog/services?token=" + this->config.get_token();
	url += "&dc=" + this->config.get_datacenter();
	url += "&ns=" + this->service.service_namespace;

	WFHttpTask *task = WFTaskFactory::create_http_task(url, 0, this->retry_max,
													   list_service_callback);
	HttpRequest *req = task->get_req();

	req->add_header_pair("Content-Type", "application/json");
	task->user_data = this;
	return task;
}

static void
print_json_value(const json_value_t *val, int depth, std::string& json_str);

static bool create_register_request(const json_value_t *root,
									const struct ConsulService *service,
									const ConsulConfig& config);

/*
 * Create the HTTP PUT task that registers this->service with the agent.
 * The JSON payload is built by create_register_request() and serialized by
 * print_json_value(). Returns NULL, after dismissing the HTTP task, when the
 * JSON root cannot be created, the payload ends up empty, or the body cannot
 * be appended to the request.
 */
WFHttpTask *WFConsulTask::create_register_task()
{
	std::string payload;

	std::string url = this->proxy_url;
	url += "/v1/agent/service/register?replace-existing-checks=";
	url += this->config.get_replace_checks() ? "true" : "false";

	WFHttpTask *task = WFTaskFactory::create_http_task(url, 0, this->retry_max,
													   register_callback);
	HttpRequest *req = task->get_req();

	req->set_method(HttpMethodPut);
	req->add_header_pair("Content-Type", "application/json");
	/* The token is sent as a header only when one is configured. */
	if (!this->config.get_token().empty())
		req->add_header_pair("X-Consul-Token", this->config.get_token());

	json_value_t *root = json_value_create(JSON_VALUE_OBJECT);
	if (root)
	{
		/* payload stays empty if building the request document fails. */
		if (create_register_request(root, &this->service, this->config))
			print_json_value(root, 0, payload);

		json_value_destroy(root);
		if (!payload.empty() && req->append_output_body(payload))
		{
			task->user_data = this;
			return task;
		}
	}

	/* Failure: drop the never-started HTTP task. */
	task->dismiss();
	return NULL;
}

/*
 * Create the HTTP PUT task that deregisters service_id, with the namespace
 * as a query parameter. Completion is handled by register_callback, the same
 * callback used for registration.
 */
WFHttpTask *WFConsulTask::create_deregister_task()
{
	std::string url = this->proxy_url;

	url += "/v1/agent/service/deregister/" + this->service.service_id;
	url += "?ns=" + this->service.service_namespace;

	WFHttpTask *task = WFTaskFactory::create_http_task(url, 0, this->retry_max,
													   register_callback);
	HttpRequest *req = task->get_req();

	req->set_method(HttpMethodPut);
	req->add_header_pair("Content-Type", "application/json");

	std::string token = this->config.get_token();
	if (!token.empty())
		req->add_header_pair("X-Consul-Token", token);

	task->user_data = this;
	return task;
}

/*
 * Check the outcome of a finished HTTP sub-task. On failure the HTTP task's
 * state/error (or a CHECK_RESPONSE_FAILED task error for a non-"200" status)
 * are copied into consul_task and false is returned.
 */
bool WFConsulTask::check_task_result(WFHttpTask *task,
									 WFConsulTask *consul_task)
{
	if (task->get_state() != WFT_STATE_SUCCESS)
	{
		consul_task->state = task->get_state();
		consul_task->error = task->get_error();
		return false;
	}

	protocol::HttpResponse *resp = task->get_resp();
	if (strcmp(resp->get_status_code(), "200") != 0)
	{
consul_task->state = WFT_STATE_TASK_ERROR;
		consul_task->error = WFT_ERR_CONSUL_CHECK_RESPONSE_FAILED;
		return false;
	}

	return true;
}

/*
 * Read the "X-Consul-Index" response header and return it as a number.
 * Returns 0 when the header is absent or the parsed value is negative.
 */
long long WFConsulTask::get_consul_index(HttpResponse *resp)
{
	long long consul_index = 0;

	// get consul-index from http header
	protocol::HttpHeaderCursor cursor(resp);
	std::string consul_index_str;
	if (cursor.find("X-Consul-Index", consul_index_str))
	{
		consul_index = strtoll(consul_index_str.c_str(), NULL, 10);
		if (consul_index < 0)
			consul_index = 0;
	}

	return consul_index;
}

/*
 * HTTP callback for discover requests. On success the stored consul index
 * is updated; it is reset to 0 if the returned index is lower than the one
 * stored before. The HTTP response is always moved into the consul task and
 * 'finish' is set.
 */
void WFConsulTask::discover_callback(WFHttpTask *task)
{
	WFConsulTask *t = (WFConsulTask*)task->user_data;

	if (WFConsulTask::check_task_result(task, t))
	{
		protocol::HttpResponse *resp = task->get_resp();
		long long consul_index = t->get_consul_index(resp);
		long long last_consul_index = t->get_consul_index();
		t->set_consul_index(consul_index < last_consul_index ? 0 : consul_index);
		t->state = task->get_state();
	}

	t->http_resp = std::move(*task->get_resp());
	t->finish = true;
}

/*
 * HTTP callback for list-service requests: records the state on success
 * (check_task_result records failures), moves the response into the consul
 * task and sets 'finish'.
 */
void WFConsulTask::list_service_callback(WFHttpTask *task)
{
	WFConsulTask *t = (WFConsulTask*)task->user_data;

	if (WFConsulTask::check_task_result(task, t))
		t->state = task->get_state();

	t->http_resp = std::move(*task->get_resp());
	t->finish = true;
}

/*
 * HTTP callback shared by register and deregister requests; same handling
 * as list_service_callback.
 */
void WFConsulTask::register_callback(WFHttpTask *task)
{
	WFConsulTask *t = (WFConsulTask *)task->user_data;

	if (WFConsulTask::check_task_result(task, t))
		t->state = task->get_state();

	t->http_resp = std::move(*task->get_resp());
	t->finish = true;
}

/*
 * Initialize the client from a proxy URL and a config. The stored proxy_url
 * is rebuilt as "scheme://host[:port]" from the parsed URI, so any other URL
 * components are not kept. Returns 0 on success and -1 on parse failure,
 * setting errno = EINVAL when the parser reports URI_STATE_INVALID.
 * NOTE(review): uri.scheme and uri.host are used without a NULL check --
 * confirm URIParser::parse() guarantees both on success.
 */
int WFConsulClient::init(const std::string& proxy_url, ConsulConfig config)
{
	ParsedURI uri;

	if (URIParser::parse(proxy_url, uri) >= 0)
	{
		this->proxy_url = uri.scheme;
		this->proxy_url += "://";
		this->proxy_url += uri.host;
		if (uri.port)
		{
			this->proxy_url += ":";
			this->proxy_url += uri.port;
		}

		this->config = std::move(config);
		return 0;
	}
	else if (uri.state == URI_STATE_INVALID)
		errno = EINVAL;

	return -1;
}

/* Convenience overload: initialize with a default-constructed ConsulConfig. */
int WFConsulClient::init(const std::string& proxy_url)
{
	return this->init(proxy_url,
ConsulConfig());
}

/*
 * Create a discover task for service_name in service_namespace. The task is
 * heap-allocated, gets an empty service id, and receives a copy of the
 * client's config.
 */
WFConsulTask *WFConsulClient::create_discover_task(
									const std::string& service_namespace,
									const std::string& service_name,
									int retry_max,
									consul_callback_t cb)
{
	WFConsulTask *task = new WFConsulTask(this->proxy_url, service_namespace,
										  service_name, "", retry_max,
										  std::move(cb));
	task->set_api_type(CONSUL_API_TYPE_DISCOVER);
	task->set_config(this->config);
	return task;
}

/*
 * Create a task that lists the services of service_namespace. Service name
 * and service id are left empty.
 */
WFConsulTask *WFConsulClient::create_list_service_task(
									const std::string& service_namespace,
									int retry_max,
									consul_callback_t cb)
{
	WFConsulTask *task = new WFConsulTask(this->proxy_url, service_namespace,
										  "", "", retry_max,
										  std::move(cb));
	task->set_api_type(CONSUL_API_TYPE_LIST_SERVICE);
	task->set_config(this->config);
	return task;
}

/*
 * Create a task that registers the service identified by service_name and
 * service_id in service_namespace.
 */
WFConsulTask *WFConsulClient::create_register_task(
									const std::string& service_namespace,
									const std::string& service_name,
									const std::string& service_id,
									int retry_max,
									consul_callback_t cb)
{
	WFConsulTask *task = new WFConsulTask(this->proxy_url, service_namespace,
										  service_name, service_id, retry_max,
										  std::move(cb));
	task->set_api_type(CONSUL_API_TYPE_REGISTER);
	task->set_config(this->config);
	return task;
}

/*
 * Create a task that deregisters service_id in service_namespace. The
 * service name is left empty.
 */
WFConsulTask *WFConsulClient::create_deregister_task(
									const std::string& service_namespace,
									const std::string& service_id,
									int retry_max,
									consul_callback_t cb)
{
	WFConsulTask *task = new WFConsulTask(this->proxy_url, service_namespace,
										  "", service_id, retry_max,
										  std::move(cb));
	task->set_api_type(CONSUL_API_TYPE_DEREGISTER);
	task->set_config(this->config);
	return task;
}

/*
 * Append one tagged address, as an object with "Address" and "Port", under
 * key 'name' in tagged_obj. An address whose host part (first) is empty is
 * skipped and counts as success. Returns false if any JSON append fails.
 */
static bool create_tagged_address(const ConsulAddress& consul_address,
								  const std::string& name,
								  json_object_t *tagged_obj)
{
	if (consul_address.first.empty())
		return true;

	const json_value_t *val = json_object_append(tagged_obj, name.c_str(),
												 JSON_VALUE_OBJECT);
	if (!val)
		return false;

	json_object_t *obj = json_value_object(val);
	if (!json_object_append(obj, "Address", JSON_VALUE_STRING,
							consul_address.first.c_str()))
		return false;

	if (!json_object_append(obj, "Port",
JSON_VALUE_NUMBER,
							(double)consul_address.second))
		return false;

	return true;
}

/*
 * Append a "Check" object describing the configured health check to 'obj'.
 * Returns true without appending anything when health checking is disabled
 * in the config, and false as soon as any JSON append fails.
 */
static bool create_health_check(const ConsulConfig& config, json_object_t *obj)
{
	const json_value_t *val;
	std::string str;

	if (!config.get_health_check())
		return true;

	val = json_object_append(obj, "Check", JSON_VALUE_OBJECT);
	if (!val)
		return false;

	/* From here on 'obj' refers to the new "Check" object. */
	obj = json_value_object(val);

	str = config.get_check_name();
	if (!json_object_append(obj, "Name", JSON_VALUE_STRING, str.c_str()))
		return false;

	str = config.get_check_notes();
	if (!json_object_append(obj, "Notes", JSON_VALUE_STRING, str.c_str()))
		return false;

	/* HTTP check fields are emitted only when an HTTP URL is configured. */
	str = config.get_check_http_url();
	if (!str.empty())
	{
		if (!json_object_append(obj, "HTTP", JSON_VALUE_STRING, str.c_str()))
			return false;

		str = config.get_check_http_method();
		if (!json_object_append(obj, "Method", JSON_VALUE_STRING, str.c_str()))
			return false;

		str = config.get_http_body();
		if (!json_object_append(obj, "Body", JSON_VALUE_STRING, str.c_str()))
			return false;

		val = json_object_append(obj, "Header", JSON_VALUE_OBJECT);
		if (!val)
			return false;

		/* Each header name maps to an array of string values. */
		json_object_t *header_obj = json_value_object(val);

		for (const auto& header : *config.get_http_headers())
		{
			val = json_object_append(header_obj, header.first.c_str(),
									 JSON_VALUE_ARRAY);
			if (!val)
				return false;

			json_array_t *arr = json_value_array(val);

			for (const auto& value : header.second)
			{
				if (!json_array_append(arr, JSON_VALUE_STRING, value.c_str()))
					return false;
			}
		}
	}

	/* The TCP check target is emitted only when configured. */
	str = config.get_check_tcp();
	if (!str.empty())
	{
		if (!json_object_append(obj, "TCP", JSON_VALUE_STRING, str.c_str()))
			return false;
	}

	str = config.get_initial_status();
	if (!json_object_append(obj, "Status", JSON_VALUE_STRING, str.c_str()))
		return false;

	/* Durations are formatted by convert_time_to_str(). */
	str = convert_time_to_str(config.get_auto_deregister_time());
	if (!json_object_append(obj, "DeregisterCriticalServiceAfter",
							JSON_VALUE_STRING, str.c_str()))
		return false;

	str = convert_time_to_str(config.get_check_interval());
	if (!json_object_append(obj, "Interval", JSON_VALUE_STRING, str.c_str()))
		return
false;

	str = convert_time_to_str(config.get_check_timeout());
	if (!json_object_append(obj, "Timeout", JSON_VALUE_STRING, str.c_str()))
		return false;

	if (!json_object_append(obj, "SuccessBeforePassing", JSON_VALUE_NUMBER,
							(double)config.get_success_times()))
		return false;

	if (!json_object_append(obj, "FailuresBeforeCritical", JSON_VALUE_NUMBER,
							(double)config.get_failure_times()))
		return false;

	return true;
}

/*
 * Fill the JSON object 'root' with the service registration document built
 * from 'service' and 'config': ID, Name, optional ns, Tags, Address, Port,
 * Meta, EnableTagOverride, TaggedAddresses and the optional health check.
 * Returns false if 'root' is not an object or any JSON append fails.
 */
static bool create_register_request(const json_value_t *root,
									const struct ConsulService *service,
									const ConsulConfig& config)
{
	const json_value_t *val;
	json_object_t *obj;

	obj = json_value_object(root);
	if (!obj)
		return false;

	if (!json_object_append(obj, "ID", JSON_VALUE_STRING,
							service->service_id.c_str()))
		return false;

	if (!json_object_append(obj, "Name", JSON_VALUE_STRING,
							service->service_name.c_str()))
		return false;

	/* The namespace is emitted only when it is non-empty. */
	if (!service->service_namespace.empty())
	{
		if (!json_object_append(obj, "ns", JSON_VALUE_STRING,
								service->service_namespace.c_str()))
			return false;
	}

	val = json_object_append(obj, "Tags", JSON_VALUE_ARRAY);
	if (!val)
		return false;

	json_array_t *arr = json_value_array(val);

	for (const auto& tag : service->tags)
	{
		if (!json_array_append(arr, JSON_VALUE_STRING, tag.c_str()))
			return false;
	}

	if (!json_object_append(obj, "Address", JSON_VALUE_STRING,
							service->service_address.first.c_str()))
		return false;

	if (!json_object_append(obj, "Port", JSON_VALUE_NUMBER,
							(double)service->service_address.second))
		return false;

	val = json_object_append(obj, "Meta", JSON_VALUE_OBJECT);
	if (!val)
		return false;

	json_object_t *meta_obj = json_value_object(val);

	for (const auto& meta_kv : service->meta)
	{
		if (!json_object_append(meta_obj, meta_kv.first.c_str(),
								JSON_VALUE_STRING, meta_kv.second.c_str()))
			return false;
	}

	/* The boolean is encoded through the JSON value type itself. */
	int type = service->tag_override ?
JSON_VALUE_TRUE : JSON_VALUE_FALSE; if (!json_object_append(obj, "EnableTagOverride", type)) return false; val = json_object_append(obj, "TaggedAddresses", JSON_VALUE_OBJECT); if (!val) return false; json_object_t *tagged_obj = json_value_object(val); if (!tagged_obj) return false; if (!create_tagged_address(service->lan, "lan", tagged_obj)) return false; if (!create_tagged_address(service->lan_ipv4, "lan_ipv4", tagged_obj)) return false; if (!create_tagged_address(service->lan_ipv6, "lan_ipv6", tagged_obj)) return false; if (!create_tagged_address(service->virtual_address, "virtual", tagged_obj)) return false; if (!create_tagged_address(service->wan, "wan", tagged_obj)) return false; if (!create_tagged_address(service->wan_ipv4, "wan_ipv4", tagged_obj)) return false; if (!create_tagged_address(service->wan_ipv6, "wan_ipv6", tagged_obj)) return false; // create health check if (!create_health_check(config, obj)) return false; return true; } static bool parse_list_service_result(const json_value_t *root, std::vector& result) { const json_object_t *obj; const json_value_t *val; const json_array_t *arr; const char *key; const char *str; obj = json_value_object(root); if (!obj) return false; json_object_for_each(key, val, obj) { struct ConsulServiceTags instance; instance.service_name = key; arr = json_value_array(val); if (!arr) return false; const json_value_t *tag_val; json_array_for_each(tag_val, arr) { str = json_value_string(tag_val); if (!str) return false; instance.tags.emplace_back(str); } result.emplace_back(std::move(instance)); } return true; } static bool parse_discover_node(const json_object_t *obj, struct ConsulServiceInstance *instance) { const json_value_t *val; const char *str; val = json_object_find("Node", obj); if (!val) return false; obj = json_value_object(val); if (!obj) return false; val = json_object_find("ID", obj); if (!val) return false; str = json_value_string(val); if (!str) return false; instance->node_id = str; val = 
json_object_find("Node", obj); if (!val) return false; str = json_value_string(val); if (!str) return false; instance->node_name = str; val = json_object_find("Address", obj); if (!val) return false; str = json_value_string(val); if (!str) return false; instance->node_address = str; val = json_object_find("Datacenter", obj); if (!val) return false; str = json_value_string(val); if (!str) return false; instance->dc = str; val = json_object_find("Meta", obj); if (!val) return false; const json_object_t *meta_obj = json_value_object(val); if (!meta_obj) return false; const char *meta_k; const json_value_t *meta_v; json_object_for_each(meta_k, meta_v, meta_obj) { str = json_value_string(meta_v); if (!str) return false; instance->node_meta[meta_k] = str; } val = json_object_find("CreateIndex", obj); if (val && json_value_type(val) == JSON_VALUE_NUMBER) instance->create_index = json_value_number(val); val = json_object_find("ModifyIndex", obj); if (val && json_value_type(val) == JSON_VALUE_NUMBER) instance->modify_index = json_value_number(val); return true; } static bool parse_tagged_address(const char *name, const json_value_t *tagged_val, ConsulAddress& tagged_address) { const json_value_t *val; const json_object_t *obj; const char *str; obj = json_value_object(tagged_val); if (!obj) return false; val = json_object_find(name, obj); if (!val) return false; obj = json_value_object(val); if (!obj) return false; val = json_object_find("Address", obj); if (!val) return false; str = json_value_string(val); if (!str) return false; tagged_address.first = str; val = json_object_find("Port", obj); if (!val || json_value_type(val) != JSON_VALUE_NUMBER) return false; tagged_address.second = json_value_number(val); return true; } static bool parse_service(const json_object_t *obj, struct ConsulService *service) { const json_value_t *val; const char *str; val = json_object_find("Service", obj); if (!val) return false; obj = json_value_object(val); if (!obj) return false; val = 
json_object_find("ID", obj); if (!val) return false; str = json_value_string(val); if (!str) return false; service->service_id = str; val = json_object_find("Service", obj); if (!val) return false; str = json_value_string(val); if (!str) return false; service->service_name = str; val = json_object_find("Namespace", obj); if (val) { str = json_value_string(val); if (!str) return false; service->service_namespace = str; } val = json_object_find("Address", obj); if (!val) return false; str = json_value_string(val); if (!str) return false; service->service_address.first = str; val = json_object_find("Port", obj); if (!val) return false; service->service_address.second = json_value_number(val); val = json_object_find("TaggedAddresses", obj); if (!val) return false; parse_tagged_address("lan", val, service->lan); parse_tagged_address("lan_ipv4", val, service->lan_ipv4); parse_tagged_address("lan_ipv6", val, service->lan_ipv6); parse_tagged_address("virtual", val, service->virtual_address); parse_tagged_address("wan", val, service->wan); parse_tagged_address("wan_ipv4", val, service->wan_ipv4); parse_tagged_address("wan_ipv6", val, service->wan_ipv6); val = json_object_find("Tags", obj); if (!val) return false; const json_array_t *tags_arr = json_value_array(val); if (tags_arr) { const json_value_t *tags_value; json_array_for_each(tags_value, tags_arr) { str = json_value_string(tags_value); if (!str) return false; service->tags.emplace_back(str); } } val = json_object_find("Meta", obj); if (!val) return false; const json_object_t *meta_obj = json_value_object(val); if (!meta_obj) return false; const char *meta_k; const json_value_t *meta_v; json_object_for_each(meta_k, meta_v, meta_obj) { str = json_value_string(meta_v); if (!str) return false; service->meta[meta_k] = str; } val = json_object_find("EnableTagOverride", obj); if (val) service->tag_override = (json_value_type(val) == JSON_VALUE_TRUE); return true; } static bool parse_health_check(const json_object_t *obj, 
struct ConsulServiceInstance *instance) { const json_value_t *val; const char *str; val = json_object_find("Checks", obj); if (!val) return false; const json_array_t *check_arr = json_value_array(val); if (!check_arr) return false; const json_value_t *arr_val; json_array_for_each(arr_val, check_arr) { obj = json_value_object(arr_val); if (!obj) return false; val = json_object_find("ServiceName", obj); if (!val) return false; str = json_value_string(val); if (!str) return false; std::string check_service_name = str; val = json_object_find("ServiceID", obj); if (!val) return false; str = json_value_string(val); if (!str) return false; std::string check_service_id = str; if (check_service_id.empty() || check_service_name.empty()) continue; val = json_object_find("CheckID", obj); if (!val) return false; str = json_value_string(val); if (!str) return false; instance->check_id = str; val = json_object_find("Name", obj); if (!val) return false; str = json_value_string(val); if (!str) return false; instance->check_name = str; val = json_object_find("Status", obj); if (!val) return false; str = json_value_string(val); if (!str) return false; instance->check_status = str; val = json_object_find("Notes", obj); if (val) { str = json_value_string(val); if (!str) return false; instance->check_notes = str; } val = json_object_find("Output", obj); if (val) { str = json_value_string(val); if (!str) return false; instance->check_output = str; } val = json_object_find("Type", obj); if (val) { str = json_value_string(val); if (!str) return false; instance->check_type = str; } break; //only one effective service health check } return true; } static bool parse_discover_result(const json_value_t *root, std::vector& result) { const json_array_t *arr = json_value_array(root); const json_value_t *val; const json_object_t *obj; if (!arr) return false; json_array_for_each(val, arr) { struct ConsulServiceInstance instance; obj = json_value_object(val); if (!obj) return false; if 
(!parse_discover_node(obj, &instance)) return false; if (!parse_service(obj, &instance.service)) return false; parse_health_check(obj, &instance); result.emplace_back(std::move(instance)); } return true; } static void print_json_string(const char *str, std::string& json_str) { json_str += "\""; while (*str) { switch (*str) { case '\r': json_str += "\\r"; break; case '\n': json_str += "\\n"; break; case '\f': json_str += "\\f"; break; case '\b': json_str += "\\b"; break; case '\"': json_str += "\\\""; break; case '\t': json_str += "\\t"; break; case '\\': json_str += "\\\\"; break; default: if ((unsigned char)*str < 0x20) { char buf[8]; sprintf(buf, "\\u00%02x", *str); json_str += buf; } else json_str += *str; break; } str++; } json_str += "\""; } static void print_json_number(double number, std::string& json_str) { long long integer = number; if (integer == number) json_str += std::to_string(integer); else json_str += std::to_string(number); } static void print_json_object(const json_object_t *obj, int depth, std::string& json_str) { const char *name; const json_value_t *val; int n = 0; int i; json_str += "{\n"; json_object_for_each(name, val, obj) { if (n != 0) json_str += ",\n"; n++; for (i = 0; i < depth + 1; i++) json_str += " "; print_json_string(name, json_str); json_str += ": "; print_json_value(val, depth + 1, json_str); } json_str += "\n"; for (i = 0; i < depth; i++) json_str += " "; json_str += "}"; } static void print_json_array(const json_array_t *arr, int depth, std::string& json_str) { const json_value_t *val; int n = 0; int i; json_str += "[\n"; json_array_for_each(val, arr) { if (n != 0) json_str += ",\n"; n++; for (i = 0; i < depth + 1; i++) json_str += " "; print_json_value(val, depth + 1, json_str); } json_str += "\n"; for (i = 0; i < depth; i++) json_str += " "; json_str += "]"; } static void print_json_value(const json_value_t *val, int depth, std::string& json_str) { switch (json_value_type(val)) { case JSON_VALUE_STRING: 
print_json_string(json_value_string(val), json_str); break; case JSON_VALUE_NUMBER: print_json_number(json_value_number(val), json_str); break; case JSON_VALUE_OBJECT: print_json_object(json_value_object(val), depth, json_str); break; case JSON_VALUE_ARRAY: print_json_array(json_value_array(val), depth, json_str); break; case JSON_VALUE_TRUE: json_str += "true"; break; case JSON_VALUE_FALSE: json_str += "false"; break; case JSON_VALUE_NULL: json_str += "null"; break; } } workflow-0.11.8/src/client/WFConsulClient.h000066400000000000000000000077561476003635400205570ustar00rootroot00000000000000/* Copyright (c) 2022 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Authors: Wang Zhenpeng (wangzhenpeng@sogou-inc.com) */ #ifndef _WFCONSULCLIENT_H_ #define _WFCONSULCLIENT_H_ #include #include #include #include #include "HttpMessage.h" #include "WFTaskFactory.h" #include "ConsulDataTypes.h" class WFConsulTask; using consul_callback_t = std::function; enum { CONSUL_API_TYPE_UNKNOWN = 0, CONSUL_API_TYPE_DISCOVER, CONSUL_API_TYPE_LIST_SERVICE, CONSUL_API_TYPE_REGISTER, CONSUL_API_TYPE_DEREGISTER, }; class WFConsulTask : public WFGenericTask { public: bool get_discover_result( std::vector& result); bool get_list_service_result( std::vector& result); public: void set_service(const struct protocol::ConsulService *service); void set_api_type(int api_type) { this->api_type = api_type; } int get_api_type() const { return this->api_type; } void set_callback(consul_callback_t cb) { this->callback = std::move(cb); } void set_consul_index(long long consul_index) { this->consul_index = consul_index; } long long get_consul_index() const { return this->consul_index; } const protocol::HttpResponse *get_http_resp() const { return &this->http_resp; } protected: void set_config(protocol::ConsulConfig conf) { this->config = std::move(conf); } protected: virtual void dispatch(); virtual SubTask *done(); WFHttpTask *create_discover_task(); WFHttpTask *create_list_service_task(); WFHttpTask *create_register_task(); WFHttpTask *create_deregister_task(); std::string generate_discover_request(); long long get_consul_index(protocol::HttpResponse *resp); static bool check_task_result(WFHttpTask *task, WFConsulTask *consul_task); static void discover_callback(WFHttpTask *task); static void list_service_callback(WFHttpTask *task); static void register_callback(WFHttpTask *task); protected: protocol::ConsulConfig config; struct protocol::ConsulService service; std::string proxy_url; int retry_max; int api_type; bool finish; long long consul_index; protocol::HttpResponse http_resp; consul_callback_t callback; protected: WFConsulTask(const std::string& proxy_url, 
const std::string& service_namespace, const std::string& service_name, const std::string& service_id, int retry_max, consul_callback_t&& cb); virtual ~WFConsulTask() { } friend class WFConsulClient; }; class WFConsulClient { public: // example: http://127.0.0.1:8500 int init(const std::string& proxy_url); int init(const std::string& proxy_url, protocol::ConsulConfig config); void deinit() { } WFConsulTask *create_discover_task(const std::string& service_namespace, const std::string& service_name, int retry_max, consul_callback_t cb); WFConsulTask *create_list_service_task(const std::string& service_namespace, int retry_max, consul_callback_t cb); WFConsulTask *create_register_task(const std::string& service_namespace, const std::string& service_name, const std::string& service_id, int retry_max, consul_callback_t cb); WFConsulTask *create_deregister_task(const std::string& service_namespace, const std::string& service_id, int retry_max, consul_callback_t cb); private: std::string proxy_url; protocol::ConsulConfig config; public: virtual ~WFConsulClient() { } }; #endif workflow-0.11.8/src/client/WFDnsClient.cc000066400000000000000000000144471476003635400201710ustar00rootroot00000000000000/* Copyright (c) 2021 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Author: Liu Kai (liukaidx@sogou-inc.com) */ #include #include #include #include "URIParser.h" #include "StringUtil.h" #include "WFDnsClient.h" using namespace protocol; using DnsCtx = std::function; using ComplexTask = WFComplexClientTask; class DnsParams { public: struct dns_params { std::vector uris; std::vector search_list; int ndots; int attempts; bool rotate; }; public: DnsParams() { this->ref = new std::atomic(1); this->params = new dns_params(); } DnsParams(const DnsParams& p) { this->ref = p.ref; this->params = p.params; this->incref(); } DnsParams& operator=(const DnsParams& p) { if (this != &p) { this->decref(); this->ref = p.ref; this->params = p.params; this->incref(); } return *this; } ~DnsParams() { this->decref(); } const dns_params *get_params() const { return this->params; } dns_params *get_params() { return this->params; } private: void incref() { (*this->ref)++; } void decref() { if (--*this->ref == 0) { delete this->params; delete this->ref; } } private: dns_params *params; std::atomic *ref; }; enum { DNS_STATUS_TRY_ORIGIN_DONE = 0, DNS_STATUS_TRY_ORIGIN_FIRST = 1, DNS_STATUS_TRY_ORIGIN_LAST = 2 }; struct DnsStatus { std::string origin_name; std::string current_name; size_t next_server; // next server to try size_t last_server; // last server to try size_t next_domain; // next search domain to try int attempts_left; int try_origin_state; }; static int __get_ndots(const std::string& s) { int ndots = 0; for (size_t i = 0; i < s.size(); i++) ndots += s[i] == '.'; return ndots; } static bool __has_next_name(const DnsParams::dns_params *p, struct DnsStatus *s) { if (s->try_origin_state == DNS_STATUS_TRY_ORIGIN_FIRST) { s->current_name = s->origin_name; s->try_origin_state = DNS_STATUS_TRY_ORIGIN_DONE; return true; } if (s->next_domain < p->search_list.size()) { s->current_name = s->origin_name; s->current_name.push_back('.'); s->current_name.append(p->search_list[s->next_domain]); s->next_domain++; return true; } if (s->try_origin_state == 
DNS_STATUS_TRY_ORIGIN_LAST) { s->current_name = s->origin_name; s->try_origin_state = DNS_STATUS_TRY_ORIGIN_DONE; return true; } return false; } static void __callback_internal(WFDnsTask *task, const DnsParams& params, struct DnsStatus& s) { ComplexTask *ctask = static_cast(task); int state = task->get_state(); DnsRequest *req = task->get_req(); DnsResponse *resp = task->get_resp(); const auto *p = params.get_params(); int rcode = resp->get_rcode(); bool try_next_server = state != WFT_STATE_SUCCESS || rcode == DNS_RCODE_SERVER_FAILURE || rcode == DNS_RCODE_NOT_IMPLEMENTED || rcode == DNS_RCODE_REFUSED; bool try_next_name = rcode == DNS_RCODE_FORMAT_ERROR || rcode == DNS_RCODE_NAME_ERROR || resp->get_ancount() == 0; if (try_next_server) { if (s.last_server == s.next_server) s.attempts_left--; if (s.attempts_left <= 0) return; s.next_server = (s.next_server + 1) % p->uris.size(); ctask->set_redirect(p->uris[s.next_server]); return; } if (try_next_name && __has_next_name(p, &s)) { req->set_question_name(s.current_name.c_str()); ctask->set_redirect(p->uris[s.next_server]); return; } } int WFDnsClient::init(const std::string& url) { return this->init(url, "", 1, 2, false); } int WFDnsClient::init(const std::string& url, const std::string& search_list, int ndots, int attempts, bool rotate) { std::vector hosts; std::vector uris; std::string host; ParsedURI uri; this->id = 0; hosts = StringUtil::split_filter_empty(url, ','); for (size_t i = 0; i < hosts.size(); i++) { host = hosts[i]; if (strncasecmp(host.c_str(), "dns://", 6) != 0 && strncasecmp(host.c_str(), "dnss://", 7) != 0) { host = "dns://" + host; } if (URIParser::parse(host, uri) != 0) return -1; uris.emplace_back(std::move(uri)); } if (uris.empty() || ndots < 0 || attempts < 1) { errno = EINVAL; return -1; } this->params = new DnsParams; DnsParams::dns_params *q = ((DnsParams *)this->params)->get_params(); q->uris = std::move(uris); q->search_list = StringUtil::split_filter_empty(search_list, ','); q->ndots = 
ndots > 15 ? 15 : ndots; q->attempts = attempts > 5 ? 5 : attempts; q->rotate = rotate; return 0; } void WFDnsClient::deinit() { delete (DnsParams *)this->params; this->params = NULL; } WFDnsTask *WFDnsClient::create_dns_task(const std::string& name, dns_callback_t callback) { DnsParams::dns_params *p = ((DnsParams *)this->params)->get_params(); struct DnsStatus status; size_t next_server; WFDnsTask *task; DnsRequest *req; next_server = p->rotate ? this->id++ % p->uris.size() : 0; status.origin_name = name; status.next_domain = 0; status.attempts_left = p->attempts; status.try_origin_state = DNS_STATUS_TRY_ORIGIN_FIRST; if (!name.empty() && name.back() == '.') status.next_domain = p->search_list.size(); else if (__get_ndots(name) < p->ndots) status.try_origin_state = DNS_STATUS_TRY_ORIGIN_LAST; __has_next_name(p, &status); task = WFTaskFactory::create_dns_task(p->uris[next_server], 0, std::move(callback)); status.next_server = next_server; status.last_server = (next_server + p->uris.size() - 1) % p->uris.size(); req = task->get_req(); req->set_question(status.current_name.c_str(), DNS_TYPE_A, DNS_CLASS_IN); req->set_rd(1); ComplexTask *ctask = static_cast(task); *ctask->get_mutable_ctx() = std::bind(__callback_internal, std::placeholders::_1, *(DnsParams *)params, status); return task; } workflow-0.11.8/src/client/WFDnsClient.h000066400000000000000000000021651476003635400200250ustar00rootroot00000000000000/* Copyright (c) 2021 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Author: Liu Kai (liukaidx@sogou-inc.com) */ #ifndef _WFDNSCLIENT_H_ #define _WFDNSCLIENT_H_ #include #include #include "WFTaskFactory.h" #include "DnsMessage.h" class WFDnsClient { public: int init(const std::string& url); int init(const std::string& url, const std::string& search_list, int ndots, int attempts, bool rotate); void deinit(); WFDnsTask *create_dns_task(const std::string& name, dns_callback_t callback); private: void *params; std::atomic id; public: virtual ~WFDnsClient() { } }; #endif workflow-0.11.8/src/client/WFKafkaClient.cc000066400000000000000000001232251476003635400204550ustar00rootroot00000000000000/* Copyright (c) 2020 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Authors: Wang Zhulei (wangzhulei@sogou-inc.com) Xie Han (xiehan@sogou-inc.com) Liu Kai (liukaidx@sogou-inc.com) */ #include #include #include #include #include #include #include #include #include #include "WFTaskError.h" #include "StringUtil.h" #include "KafkaTaskImpl.inl" #include "WFKafkaClient.h" #define KAFKA_HEARTBEAT_INTERVAL (3 * 1000 * 1000) #define KAFKA_CGROUP_UNINIT 0 #define KAFKA_CGROUP_DOING 1 #define KAFKA_CGROUP_DONE 2 #define KAFKA_CGROUP_NONE 3 #define KAFKA_HEARTBEAT_UNINIT 0 #define KAFKA_HEARTBEAT_DOING 1 #define KAFKA_HEARTBEAT_DONE 2 #define KAFKA_DEINIT (1<<30) using namespace protocol; using ComplexKafkaTask = WFComplexClientTask; class KafkaMember { public: KafkaMember() : scheme("kafka://"), ref(1) { this->transport_type = TT_TCP; this->cgroup_status = KAFKA_CGROUP_NONE; this->heartbeat_status = KAFKA_HEARTBEAT_UNINIT; this->meta_doing = false; this->cgroup_outdated = false; this->client_deinit = false; this->heartbeat_series = NULL; } void incref() { ++this->ref; } void decref() { if (--this->ref == 0) delete this; } enum TransportType transport_type; std::string scheme; std::vector broker_hosts; SSL_CTX *ssl_ctx; KafkaCgroup cgroup; KafkaMetaList meta_list; KafkaBrokerMap broker_map; KafkaConfig config; std::map meta_status; std::mutex mutex; char cgroup_status; char heartbeat_status; bool meta_doing; bool cgroup_outdated; bool client_deinit; void *heartbeat_series; size_t cgroup_wait_cnt; size_t meta_wait_cnt; std::atomic ref; }; class KafkaClientTask : public WFKafkaTask { public: KafkaClientTask(const std::string& query, int retry_max, kafka_callback_t&& callback, WFKafkaClient *client) : WFKafkaTask(retry_max, std::move(callback)) { this->api_type = Kafka_Unknown; this->kafka_error = 0; this->member = client->member; this->query = query; this->member->incref(); this->member->mutex.lock(); this->config = client->member->config; if (!this->member->broker_hosts.empty()) { int rpos = rand() % this->member->broker_hosts.size(); this->url 
= this->member->broker_hosts.at(rpos); } this->member->mutex.unlock(); this->info_generated = false; this->msg = NULL; } virtual ~KafkaClientTask() { this->member->decref(); } std::string *get_url() { return &this->url; } protected: virtual bool add_topic(const std::string& topic); virtual bool add_toppar(const KafkaToppar& toppar); virtual bool add_produce_record(const std::string& topic, int partition, KafkaRecord record); virtual bool add_offset_toppar(const KafkaToppar& toppar); virtual void dispatch(); virtual void parse_query(); virtual void generate_info(); private: static void kafka_meta_callback(__WFKafkaTask *task); static void kafka_merge_meta_list(KafkaMetaList *dst, KafkaMetaList *src); static void kafka_merge_broker_list(const std::string& scheme, std::vector *hosts, KafkaBrokerMap *dst, KafkaBrokerList *src); static void kafka_cgroup_callback(__WFKafkaTask *task); static void kafka_offsetcommit_callback(__WFKafkaTask *task); static void kafka_parallel_callback(const ParallelWork *pwork); static void kafka_timer_callback(WFTimerTask *task); static void kafka_heartbeat_callback(__WFKafkaTask *task); static void kafka_leavegroup_callback(__WFKafkaTask *task); static void kafka_rebalance_proc(KafkaMember *member, SeriesWork *series); static void kafka_rebalance_callback(__WFKafkaTask *task); void kafka_move_task_callback(__WFKafkaTask *task); void kafka_process_toppar_offset(KafkaToppar *task_toppar); bool compare_topics(KafkaClientTask *task); bool check_cgroup(); bool check_meta(); int arrange_toppar(int api_type); int arrange_produce(); int arrange_fetch(); int arrange_commit(); int arrange_offset(); int dispatch_locked(); KafkaBroker *get_broker(int node_id) { return this->member->broker_map.find_item(node_id); } int get_node_id(const KafkaToppar *toppar); bool get_meta_status(KafkaMetaList **uninit_meta_list); void set_meta_status(bool status); std::string get_userinfo() { return this->userinfo; } private: KafkaMember *member; KafkaBroker broker; 
std::map toppar_list_map; std::string url; std::string query; std::set topic_set; std::string userinfo; bool info_generated; bool wait_cgroup; void *msg; friend class WFKafkaClient; }; int KafkaClientTask::get_node_id(const KafkaToppar *toppar) { int preferred_read_replica = toppar->get_preferred_read_replica(); if (preferred_read_replica >= 0) return preferred_read_replica; bool flag = false; this->member->meta_list.rewind(); KafkaMeta *meta; while ((meta = this->member->meta_list.get_next()) != NULL) { if (strcmp(meta->get_topic(), toppar->get_topic()) == 0) { flag = true; break; } } const kafka_broker_t *broker = NULL; if (flag) broker = meta->get_broker(toppar->get_partition()); if (!broker) return -1; return broker->node_id; } void KafkaClientTask::kafka_offsetcommit_callback(__WFKafkaTask *task) { KafkaClientTask *t = (KafkaClientTask *)task->user_data; if (task->get_state() == WFT_STATE_SUCCESS) t->result.set_resp(std::move(*task->get_resp()), 0); t->finish = true; t->state = task->get_state(); t->error = task->get_error(); t->kafka_error = *static_cast(task)->get_mutable_ctx(); } void KafkaClientTask::kafka_leavegroup_callback(__WFKafkaTask *task) { KafkaClientTask *t = (KafkaClientTask *)task->user_data; t->finish = true; t->state = task->get_state(); t->error = task->get_error(); t->kafka_error = *static_cast(task)->get_mutable_ctx(); } void KafkaClientTask::kafka_rebalance_callback(__WFKafkaTask *task) { KafkaMember *member = (KafkaMember *)task->user_data; SeriesWork *series = series_of(task); size_t max; member->mutex.lock(); if (member->client_deinit) { member->mutex.unlock(); member->decref(); return; } if (task->get_state() == WFT_STATE_SUCCESS) { member->cgroup_status = KAFKA_CGROUP_DONE; member->cgroup = std::move(*(task->get_resp()->get_cgroup())); if (member->heartbeat_status == KAFKA_HEARTBEAT_UNINIT) { __WFKafkaTask *kafka_task; KafkaBroker *coordinator = member->cgroup.get_coordinator(); kafka_task = 
__WFKafkaTaskFactory::create_kafka_task(member->transport_type, coordinator->get_host(), coordinator->get_port(), member->ssl_ctx, "", 0, kafka_heartbeat_callback); kafka_task->user_data = member; kafka_task->get_req()->set_api_type(Kafka_Heartbeat); kafka_task->get_req()->set_cgroup(member->cgroup); kafka_task->get_req()->set_broker(*coordinator); series->push_back(kafka_task); member->heartbeat_status = KAFKA_HEARTBEAT_DOING; member->heartbeat_series = series; } max = member->cgroup_wait_cnt; char name[64]; snprintf(name, 64, "%p.cgroup", member); member->mutex.unlock(); WFTaskFactory::signal_by_name(name, NULL, max); } else { kafka_rebalance_proc(member, series); member->mutex.unlock(); } } void KafkaClientTask::kafka_rebalance_proc(KafkaMember *member, SeriesWork *series) { KafkaBroker *coordinator = member->cgroup.get_coordinator(); __WFKafkaTask *task; task = __WFKafkaTaskFactory::create_kafka_task(member->transport_type, coordinator->get_host(), coordinator->get_port(), member->ssl_ctx, "", 0, kafka_rebalance_callback); task->user_data = member; task->get_req()->set_config(member->config); task->get_req()->set_api_type(Kafka_FindCoordinator); task->get_req()->set_cgroup(member->cgroup); task->get_req()->set_meta_list(member->meta_list); member->cgroup_status = KAFKA_CGROUP_DOING; member->heartbeat_status = KAFKA_HEARTBEAT_UNINIT; member->cgroup_outdated = false; series->push_back(task); } void KafkaClientTask::kafka_heartbeat_callback(__WFKafkaTask *task) { KafkaMember *member = (KafkaMember *)task->user_data; SeriesWork *series = series_of(task); KafkaResponse *resp = task->get_resp(); member->mutex.lock(); if (member->client_deinit || member->heartbeat_series != series) { member->mutex.unlock(); member->decref(); return; } if (resp->get_cgroup()->get_error() == 0) { member->heartbeat_status = KAFKA_HEARTBEAT_DONE; WFTimerTask *timer_task; timer_task = WFTaskFactory::create_timer_task(KAFKA_HEARTBEAT_INTERVAL, kafka_timer_callback); timer_task->user_data = 
member; series->push_back(timer_task); } else kafka_rebalance_proc(member, series); member->mutex.unlock(); } void KafkaClientTask::kafka_timer_callback(WFTimerTask *task) { KafkaMember *member = (KafkaMember *)task->user_data; SeriesWork *series = series_of(task); member->mutex.lock(); if (member->client_deinit || member->heartbeat_series != series) { member->mutex.unlock(); member->decref(); return; } member->heartbeat_status = KAFKA_HEARTBEAT_DOING; __WFKafkaTask *kafka_task; KafkaBroker *coordinator = member->cgroup.get_coordinator(); kafka_task = __WFKafkaTaskFactory::create_kafka_task(member->transport_type, coordinator->get_host(), coordinator->get_port(), member->ssl_ctx, "", 0, kafka_heartbeat_callback); kafka_task->user_data = member; kafka_task->get_req()->set_config(member->config); kafka_task->get_req()->set_api_type(Kafka_Heartbeat); kafka_task->get_req()->set_cgroup(member->cgroup); kafka_task->get_req()->set_broker(*coordinator); series->push_back(kafka_task); member->mutex.unlock(); } void KafkaClientTask::kafka_merge_meta_list(KafkaMetaList *dst, KafkaMetaList *src) { src->rewind(); KafkaMeta *src_meta; while ((src_meta = src->get_next()) != NULL) { dst->rewind(); KafkaMeta *dst_meta; while ((dst_meta = dst->get_next()) != NULL) { if (strcmp(dst_meta->get_topic(), src_meta->get_topic()) == 0) { dst->del_cur(); delete dst_meta; break; } } dst->add_item(*src_meta); } } void KafkaClientTask::kafka_merge_broker_list(const std::string& scheme, std::vector *hosts, KafkaBrokerMap *dst, KafkaBrokerList *src) { hosts->clear(); src->rewind(); KafkaBroker *src_broker; while ((src_broker = src->get_next()) != NULL) { std::string host = scheme + src_broker->get_host() + ":" + std::to_string(src_broker->get_port()); hosts->emplace_back(std::move(host)); if (!dst->find_item(src_broker->get_node_id())) dst->add_item(*src_broker, src_broker->get_node_id()); } } void KafkaClientTask::kafka_meta_callback(__WFKafkaTask *task) { KafkaClientTask *t = (KafkaClientTask 
*)task->user_data; void *msg = NULL; size_t max; t->member->mutex.lock(); t->state = task->get_state(); t->error = task->get_error(); t->kafka_error = *static_cast(task)->get_mutable_ctx(); if (t->state == WFT_STATE_SUCCESS) { kafka_merge_meta_list(&t->member->meta_list, task->get_resp()->get_meta_list()); t->meta_list.rewind(); KafkaMeta *meta; while ((meta = t->meta_list.get_next()) != NULL) (t->member->meta_status)[meta->get_topic()] = true; kafka_merge_broker_list(t->member->scheme, &t->member->broker_hosts, &t->member->broker_map, task->get_resp()->get_broker_list()); } else { t->meta_list.rewind(); KafkaMeta *meta; while ((meta = t->meta_list.get_next()) != NULL) (t->member->meta_status)[meta->get_topic()] = false; t->finish = true; msg = t; } t->member->meta_doing = false; max = t->member->meta_wait_cnt; char name[64]; snprintf(name, 64, "%p.meta", t->member); t->member->mutex.unlock(); WFTaskFactory::signal_by_name(name, msg, max); } void KafkaClientTask::kafka_cgroup_callback(__WFKafkaTask *task) { KafkaClientTask *t = (KafkaClientTask *)task->user_data; KafkaMember *member = t->member; SeriesWork *heartbeat_series = NULL; void *msg = NULL; size_t max; member->mutex.lock(); t->state = task->get_state(); t->error = task->get_error(); t->kafka_error = *static_cast(task)->get_mutable_ctx(); if (t->state == WFT_STATE_SUCCESS) { member->cgroup = std::move(*(task->get_resp()->get_cgroup())); kafka_merge_meta_list(&member->meta_list, task->get_resp()->get_meta_list()); t->meta_list.rewind(); KafkaMeta *meta; while ((meta = t->meta_list.get_next()) != NULL) (member->meta_status)[meta->get_topic()] = true; kafka_merge_broker_list(member->scheme, &member->broker_hosts, &member->broker_map, task->get_resp()->get_broker_list()); member->cgroup_status = KAFKA_CGROUP_DONE; if (member->heartbeat_status == KAFKA_HEARTBEAT_UNINIT) { __WFKafkaTask *kafka_task; KafkaBroker *coordinator = member->cgroup.get_coordinator(); kafka_task = 
__WFKafkaTaskFactory::create_kafka_task(member->transport_type, coordinator->get_host(), coordinator->get_port(), member->ssl_ctx, "", 0, kafka_heartbeat_callback); kafka_task->user_data = member; member->incref(); kafka_task->get_req()->set_config(member->config); kafka_task->get_req()->set_api_type(Kafka_Heartbeat); kafka_task->get_req()->set_cgroup(member->cgroup); kafka_task->get_req()->set_broker(*coordinator); heartbeat_series = Workflow::create_series_work(kafka_task, nullptr); member->heartbeat_status = KAFKA_HEARTBEAT_DOING; member->heartbeat_series = heartbeat_series; } } else { member->cgroup_status = KAFKA_CGROUP_UNINIT; member->heartbeat_status = KAFKA_HEARTBEAT_UNINIT; member->heartbeat_series = NULL; t->finish = true; msg = t; } max = member->cgroup_wait_cnt; char name[64]; snprintf(name, 64, "%p.cgroup", member); member->mutex.unlock(); WFTaskFactory::signal_by_name(name, msg, max); if (heartbeat_series) heartbeat_series->start(); } void KafkaClientTask::kafka_parallel_callback(const ParallelWork *pwork) { KafkaClientTask *t = (KafkaClientTask *)pwork->get_context(); t->finish = true; t->state = WFT_STATE_TASK_ERROR; t->error = 0; std::pair *state_error; bool flag = false; int16_t state = WFT_STATE_SUCCESS; int16_t error = 0; int kafka_error = 0; for (size_t i = 0; i < pwork->size(); i++) { state_error = (std::pair *)pwork->series_at(i)->get_context(); if ((state_error->first >> 16) != WFT_STATE_SUCCESS) { if (!flag) { flag = true; t->member->mutex.lock(); t->set_meta_status(false); t->member->mutex.unlock(); } state = state_error->first >> 16; error = state_error->first & 0xffff; kafka_error = state_error->second; } else { t->state = WFT_STATE_SUCCESS; } delete state_error; } if (t->state != WFT_STATE_SUCCESS) { t->state = state; t->error = error; t->kafka_error = kafka_error; } } void KafkaClientTask::kafka_process_toppar_offset(KafkaToppar *task_toppar) { KafkaToppar *toppar; struct list_head *pos; list_for_each(pos, 
this->member->cgroup.get_assigned_toppar_list()) { toppar = this->member->cgroup.get_assigned_toppar_by_pos(pos); if (strcmp(toppar->get_topic(), task_toppar->get_topic()) == 0 && toppar->get_partition() == task_toppar->get_partition()) { long long offset = task_toppar->get_offset() - 1; KafkaRecord *last_record = task_toppar->get_tail_record(); if (last_record) offset = last_record->get_offset(); toppar->set_offset(offset + 1); toppar->set_low_watermark(task_toppar->get_low_watermark()); toppar->set_high_watermark(task_toppar->get_high_watermark()); } } } void KafkaClientTask::kafka_move_task_callback(__WFKafkaTask *task) { auto *state_error = new std::pair; int16_t state = task->get_state(); int16_t error = task->get_error(); /* 'state' is always positive. */ state_error->first = (state << 16) | error; state_error->second = *static_cast(task)->get_mutable_ctx(); series_of(task)->set_context(state_error); KafkaTopparList *toppar_list = task->get_resp()->get_toppar_list(); if (task->get_state() == WFT_STATE_SUCCESS && task->get_resp()->get_api_type() == Kafka_Fetch) { toppar_list->rewind(); KafkaToppar *task_toppar; while ((task_toppar = toppar_list->get_next()) != NULL) kafka_process_toppar_offset(task_toppar); } if (task->get_state() == WFT_STATE_SUCCESS) { long idx = (long)(task->user_data); this->result.set_resp(std::move(*task->get_resp()), idx); } } void KafkaClientTask::generate_info() { if (this->info_generated) return; if (this->config.get_sasl_mech()) { const char *username = this->config.get_sasl_username(); const char *password = this->config.get_sasl_password(); this->userinfo.clear(); if (username) this->userinfo += StringUtil::url_encode_component(username); this->userinfo += ":"; if (password) this->userinfo += StringUtil::url_encode_component(password); this->userinfo += ":"; this->userinfo += this->config.get_sasl_mech(); this->userinfo += ":"; this->userinfo += std::to_string((intptr_t)this->member); } else { char buf[64]; snprintf(buf, 64, 
"user:pass:sasl:%p", this->member); this->userinfo = buf; } const char *hostport = this->url.c_str() + this->member->scheme.size(); this->url = this->member->scheme + this->userinfo + "@" + hostport; this->info_generated = true; } void KafkaClientTask::parse_query() { auto query_kv = URIParser::split_query_strict(this->query); int api_type = this->api_type; for (const auto &kv : query_kv) { if (strcasecmp(kv.first.c_str(), "api") == 0 && api_type == Kafka_Unknown) { for (auto& v : kv.second) { if (strcasecmp(v.c_str(), "fetch") == 0) this->api_type = Kafka_Fetch; else if (strcasecmp(v.c_str(), "produce") == 0) this->api_type = Kafka_Produce; else if (strcasecmp(v.c_str(), "commit") == 0) this->api_type = Kafka_OffsetCommit; else if (strcasecmp(v.c_str(), "meta") == 0) this->api_type = Kafka_Metadata; else if (strcasecmp(v.c_str(), "leavegroup") == 0) this->api_type = Kafka_LeaveGroup; else if (strcasecmp(v.c_str(), "listoffsets") == 0) this->api_type = Kafka_ListOffsets; } } else if (strcasecmp(kv.first.c_str(), "topic") == 0) { for (auto& v : kv.second) this->add_topic(v); } } } bool KafkaClientTask::get_meta_status(KafkaMetaList **uninit_meta_list) { this->meta_list.rewind(); KafkaMeta *meta; std::set unique; bool status = true; while ((meta = this->meta_list.get_next()) != NULL) { if (!unique.insert(meta->get_topic()).second) continue; if (!this->member->meta_status[meta->get_topic()]) { if (status) { *uninit_meta_list = new KafkaMetaList; status = false; } (*uninit_meta_list)->add_item(*meta); } } return status; } void KafkaClientTask::set_meta_status(bool status) { this->member->meta_list.rewind(); KafkaMeta *meta; while ((meta = this->member->meta_list.get_next()) != NULL) this->member->meta_status[meta->get_topic()] = false; } bool KafkaClientTask::compare_topics(KafkaClientTask *task) { auto first1 = topic_set.cbegin(), last1 = topic_set.cend(); auto first2 = task->topic_set.cbegin(), last2 = task->topic_set.cend(); int cmp; // check whether task->topic_set 
is a subset of topic_set while (first1 != last1 && first2 != last2) { cmp = first1->compare(*first2); if (cmp == 0) { ++first1; ++first2; } else if (cmp < 0) ++first1; else return false; } return first2 == last2; } bool KafkaClientTask::check_cgroup() { KafkaMember *member = this->member; if (member->cgroup_outdated && member->cgroup_status != KAFKA_CGROUP_DOING) { member->cgroup_outdated = false; member->cgroup_status = KAFKA_CGROUP_UNINIT; member->heartbeat_series = NULL; member->heartbeat_status = KAFKA_HEARTBEAT_UNINIT; } if (member->cgroup_status == KAFKA_CGROUP_DOING) { WFConditional *cond; char name[64]; snprintf(name, 64, "%p.cgroup", this->member); this->wait_cgroup = true; cond = WFTaskFactory::create_conditional(name, this, &this->msg); series_of(this)->push_front(cond); member->cgroup_wait_cnt++; return false; } if ((this->api_type == Kafka_Fetch || this->api_type == Kafka_OffsetCommit) && (member->cgroup_status == KAFKA_CGROUP_UNINIT)) { __WFKafkaTask *task; task = __WFKafkaTaskFactory::create_kafka_task(this->url, member->ssl_ctx, this->retry_max, kafka_cgroup_callback); task->user_data = this; task->get_req()->set_config(this->config); task->get_req()->set_api_type(Kafka_FindCoordinator); task->get_req()->set_cgroup(member->cgroup); task->get_req()->set_meta_list(member->meta_list); series_of(this)->push_front(this); series_of(this)->push_front(task); member->cgroup_status = KAFKA_CGROUP_DOING; member->cgroup_wait_cnt = 0; return false; } return true; } bool KafkaClientTask::check_meta() { KafkaMember *member = this->member; KafkaMetaList *uninit_meta_list; if (this->get_meta_status(&uninit_meta_list)) return true; if (member->meta_doing) { WFConditional *cond; char name[64]; snprintf(name, 64, "%p.meta", this->member); this->wait_cgroup = false; cond = WFTaskFactory::create_conditional(name, this, &this->msg); series_of(this)->push_front(cond); member->meta_wait_cnt++; } else { __WFKafkaTask *task; task = 
__WFKafkaTaskFactory::create_kafka_task(this->url, member->ssl_ctx, this->retry_max, kafka_meta_callback); task->user_data = this; task->get_req()->set_config(this->config); task->get_req()->set_api_type(Kafka_Metadata); task->get_req()->set_meta_list(*uninit_meta_list); series_of(this)->push_front(this); series_of(this)->push_front(task); member->meta_wait_cnt = 0; member->meta_doing = true; } delete uninit_meta_list; return false; } int KafkaClientTask::dispatch_locked() { KafkaMember *member = this->member; KafkaBroker *coordinator; __WFKafkaTask *task; ParallelWork *parallel; SeriesWork *series; if (this->check_cgroup() == false) return member->cgroup_wait_cnt > 0; if (this->check_meta() == false) return member->meta_wait_cnt > 0; if (arrange_toppar(this->api_type) < 0) { this->state = WFT_STATE_TASK_ERROR; this->error = WFT_ERR_KAFKA_ARRANGE_FAILED; this->finish = true; return 0; } if (this->member->cgroup_outdated) { series_of(this)->push_front(this); return 0; } switch(this->api_type) { case Kafka_Produce: if (this->toppar_list_map.size() == 0) { this->state = WFT_STATE_TASK_ERROR; this->error = WFT_ERR_KAFKA_PRODUCE_FAILED; this->finish = true; break; } parallel = Workflow::create_parallel_work(kafka_parallel_callback); this->result.create(this->toppar_list_map.size()); parallel->set_context(this); for (auto &v : this->toppar_list_map) { auto cb = std::bind(&KafkaClientTask::kafka_move_task_callback, this, std::placeholders::_1); KafkaBroker *broker = get_broker(v.first); task = __WFKafkaTaskFactory::create_kafka_task(member->transport_type, broker->get_host(), broker->get_port(), member->ssl_ctx, this->get_userinfo(), this->retry_max, std::move(cb)); task->get_req()->set_config(this->config); task->get_req()->set_toppar_list(v.second); task->get_req()->set_broker(*broker); task->get_req()->set_api_type(Kafka_Produce); task->user_data = (void *)parallel->size(); series = Workflow::create_series_work(task, nullptr); parallel->add_series(series); } 
series_of(this)->push_front(this); series_of(this)->push_front(parallel); break; case Kafka_Fetch: if (this->toppar_list_map.size() == 0) { this->state = WFT_STATE_TASK_ERROR; this->error = WFT_ERR_KAFKA_FETCH_FAILED; this->finish = true; break; } parallel = Workflow::create_parallel_work(kafka_parallel_callback); this->result.create(this->toppar_list_map.size()); parallel->set_context(this); for (auto &v : this->toppar_list_map) { auto cb = std::bind(&KafkaClientTask::kafka_move_task_callback, this, std::placeholders::_1); KafkaBroker *broker = get_broker(v.first); task = __WFKafkaTaskFactory::create_kafka_task(member->transport_type, broker->get_host(), broker->get_port(), member->ssl_ctx, this->get_userinfo(), this->retry_max, std::move(cb)); task->get_req()->set_config(this->config); task->get_req()->set_toppar_list(v.second); task->get_req()->set_broker(*broker); task->get_req()->set_api_type(Kafka_Fetch); task->user_data = (void *)parallel->size(); series = Workflow::create_series_work(task, nullptr); parallel->add_series(series); } series_of(this)->push_front(this); series_of(this)->push_front(parallel); break; case Kafka_Metadata: this->finish = true; break; case Kafka_OffsetCommit: if (!member->cgroup.get_group()) { this->state = WFT_STATE_TASK_ERROR; this->error = WFT_ERR_KAFKA_COMMIT_FAILED; this->finish = true; break; } this->result.create(1); coordinator = member->cgroup.get_coordinator(); task = __WFKafkaTaskFactory::create_kafka_task(member->transport_type, coordinator->get_host(), coordinator->get_port(), member->ssl_ctx, this->get_userinfo(), this->retry_max, kafka_offsetcommit_callback); task->user_data = this; task->get_req()->set_config(this->config); task->get_req()->set_cgroup(member->cgroup); task->get_req()->set_broker(*coordinator); task->get_req()->set_toppar_list(this->toppar_list); task->get_req()->set_api_type(this->api_type); series_of(this)->push_front(this); series_of(this)->push_front(task); break; case Kafka_LeaveGroup: if 
(!member->cgroup.get_group()) { this->state = WFT_STATE_TASK_ERROR; this->error = WFT_ERR_KAFKA_LEAVEGROUP_FAILED; this->finish = true; break; } coordinator = member->cgroup.get_coordinator(); if (!coordinator->get_host()) break; task = __WFKafkaTaskFactory::create_kafka_task(member->transport_type, coordinator->get_host(), coordinator->get_port(), member->ssl_ctx, this->get_userinfo(), 0, kafka_leavegroup_callback); task->user_data = this; task->get_req()->set_config(this->config); task->get_req()->set_api_type(Kafka_LeaveGroup); task->get_req()->set_broker(*coordinator); task->get_req()->set_cgroup(member->cgroup); series_of(this)->push_front(this); series_of(this)->push_front(task); break; case Kafka_ListOffsets: if (this->toppar_list_map.size() == 0) { this->state = WFT_STATE_TASK_ERROR; this->error = WFT_ERR_KAFKA_LIST_OFFSETS_FAILED; this->finish = true; break; } parallel = Workflow::create_parallel_work(kafka_parallel_callback); this->result.create(this->toppar_list_map.size()); parallel->set_context(this); for (auto &v : this->toppar_list_map) { auto cb = std::bind(&KafkaClientTask::kafka_move_task_callback, this, std::placeholders::_1); KafkaBroker *broker = get_broker(v.first); task = __WFKafkaTaskFactory::create_kafka_task(member->transport_type, broker->get_host(), broker->get_port(), member->ssl_ctx, this->get_userinfo(), this->retry_max, std::move(cb)); task->get_req()->set_config(this->config); task->get_req()->set_toppar_list(v.second); task->get_req()->set_broker(*broker); task->get_req()->set_api_type(Kafka_ListOffsets); task->user_data = (void *)parallel->size(); series = Workflow::create_series_work(task, nullptr); parallel->add_series(series); } series_of(this)->push_front(this); series_of(this)->push_front(parallel); break; default: this->state = WFT_STATE_TASK_ERROR; this->error = WFT_ERR_KAFKA_API_UNKNOWN; this->finish = true; break; } return 0; } void KafkaClientTask::dispatch() { if (this->finish) { this->subtask_done(); return; } if 
(this->msg) { KafkaClientTask *task = static_cast(this->msg); if (this->wait_cgroup || this->compare_topics(task) == true) { this->state = task->get_state(); this->error = task->get_error(); this->kafka_error = get_kafka_error(); this->finish = true; this->subtask_done(); return; } this->msg = NULL; } if (!this->query.empty()) this->parse_query(); this->generate_info(); int flag; this->member->mutex.lock(); flag = this->dispatch_locked(); if (flag) this->subtask_done(); this->member->mutex.unlock(); if (!flag) this->subtask_done(); } bool KafkaClientTask::add_topic(const std::string& topic) { bool flag = false; this->member->mutex.lock(); this->topic_set.insert(topic); this->member->meta_list.rewind(); KafkaMeta *meta; while ((meta = this->member->meta_list.get_next()) != NULL) { if (meta->get_topic() == topic) { flag = true; break; } } if (!flag) { this->member->meta_status[topic] = false; KafkaMeta tmp; if (!tmp.set_topic(topic)) { this->member->mutex.unlock(); return false; } this->meta_list.add_item(tmp); this->member->meta_list.add_item(tmp); if (this->member->cgroup.get_group()) this->member->cgroup_outdated = true; } else { this->meta_list.rewind(); KafkaMeta *exist; while ((exist = this->meta_list.get_next()) != NULL) { if (strcmp(exist->get_topic(), meta->get_topic()) == 0) { this->member->mutex.unlock(); return true; } } this->meta_list.add_item(*meta); } this->member->mutex.unlock(); return true; } bool KafkaClientTask::add_toppar(const KafkaToppar& toppar) { if (this->member->cgroup.get_group()) return false; bool flag = false; this->member->mutex.lock(); this->member->meta_list.rewind(); KafkaMeta *meta; while ((meta = this->member->meta_list.get_next()) != NULL) { if (strcmp(meta->get_topic(), toppar.get_topic()) == 0) { flag = true; break; } } this->topic_set.insert(toppar.get_topic()); if (!flag) { KafkaMeta tmp; if (!tmp.set_topic(toppar.get_topic())) { this->member->mutex.unlock(); return false; } KafkaToppar new_toppar; if 
(!new_toppar.set_topic_partition(toppar.get_topic(), toppar.get_partition())) { this->member->mutex.unlock(); return false; } new_toppar.set_offset(toppar.get_offset()); new_toppar.set_offset_timestamp(toppar.get_offset_timestamp()); new_toppar.set_low_watermark(toppar.get_low_watermark()); new_toppar.set_high_watermark(toppar.get_high_watermark()); this->toppar_list.add_item(new_toppar); this->meta_list.add_item(tmp); this->member->meta_list.add_item(tmp); if (this->member->cgroup.get_group()) this->member->cgroup_outdated = true; } else { this->toppar_list.rewind(); KafkaToppar *exist; while ((exist = this->toppar_list.get_next()) != NULL) { if (strcmp(exist->get_topic(), toppar.get_topic()) == 0 && exist->get_partition() == toppar.get_partition()) { this->member->mutex.unlock(); return true; } } KafkaToppar new_toppar; if (!new_toppar.set_topic_partition(toppar.get_topic(), toppar.get_partition())) { this->member->mutex.unlock(); return true; } new_toppar.set_offset(toppar.get_offset()); new_toppar.set_offset_timestamp(toppar.get_offset_timestamp()); new_toppar.set_low_watermark(toppar.get_low_watermark()); new_toppar.set_high_watermark(toppar.get_high_watermark()); this->toppar_list.add_item(new_toppar); this->meta_list.add_item(*meta); } this->member->mutex.unlock(); return true; } bool KafkaClientTask::add_produce_record(const std::string& topic, int partition, KafkaRecord record) { if (!add_topic(topic)) return false; bool flag = false; this->toppar_list.rewind(); KafkaToppar *toppar; while ((toppar = this->toppar_list.get_next()) != NULL) { if (toppar->get_topic() == topic && toppar->get_partition() == partition) { flag = true; break; } } if (!flag) { KafkaToppar new_toppar; if (!new_toppar.set_topic_partition(topic, partition)) return false; new_toppar.add_record(std::move(record)); this->toppar_list.add_item(std::move(new_toppar)); } else toppar->add_record(std::move(record)); return true; } static bool check_replace_toppar(KafkaTopparList *toppar_list, 
KafkaToppar *toppar) { bool flag = false; toppar_list->rewind(); KafkaToppar *exist; while ((exist = toppar_list->get_next()) != NULL) { if (strcmp(exist->get_topic(), toppar->get_topic()) == 0 && exist->get_partition() == toppar->get_partition()) { flag = true; if (toppar->get_offset() > exist->get_offset()) { toppar_list->add_item(std::move(*toppar)); toppar_list->del_cur(); delete exist; return true; } } } if (!flag) { toppar_list->add_item(std::move(*toppar)); return true; } return false; } int KafkaClientTask::arrange_toppar(int api_type) { switch(api_type) { case Kafka_Produce: return this->arrange_produce(); case Kafka_Fetch: return this->arrange_fetch(); case Kafka_ListOffsets: return this->arrange_offset(); case Kafka_OffsetCommit: return this->arrange_commit(); default: return 0; } } bool KafkaClientTask::add_offset_toppar(const protocol::KafkaToppar& toppar) { if (!add_topic(toppar.get_topic())) return false; KafkaToppar *exist; bool found = false; while ((exist = this->toppar_list.get_next()) != NULL) { if (strcmp(exist->get_topic(), toppar.get_topic()) == 0 && exist->get_partition() == toppar.get_partition()) { found = true; break; } } if (!found) { KafkaToppar toppar_t; toppar_t.set_topic_partition(toppar.get_topic(), toppar.get_partition()); this->toppar_list.add_item(std::move(toppar_t)); } return true; } int KafkaClientTask::arrange_offset() { this->toppar_list.rewind(); KafkaToppar *toppar; while ((toppar = this->toppar_list.get_next()) != NULL) { int node_id = get_node_id(toppar); if (node_id < 0) return -1; if (this->toppar_list_map.find(node_id) == this->toppar_list_map.end()) this->toppar_list_map[node_id] = (KafkaTopparList()); KafkaToppar new_toppar; if (!new_toppar.set_topic_partition(toppar->get_topic(), toppar->get_partition())) return -1; this->toppar_list_map[node_id].add_item(std::move(new_toppar)); } return 0; } int KafkaClientTask::arrange_commit() { this->toppar_list.rewind(); KafkaTopparList new_toppar_list; KafkaToppar *toppar; 
while ((toppar = this->toppar_list.get_next()) != NULL) { check_replace_toppar(&new_toppar_list, toppar); } this->toppar_list = std::move(new_toppar_list); return 0; } int KafkaClientTask::arrange_fetch() { this->meta_list.rewind(); for (auto& topic : topic_set) { if (this->member->cgroup.get_group()) { this->member->cgroup.assigned_toppar_rewind(); KafkaToppar *toppar; while ((toppar = this->member->cgroup.get_assigned_toppar_next()) != NULL) { if (topic.compare(toppar->get_topic()) == 0) { int node_id = get_node_id(toppar); if (node_id < 0) return -1; if (this->toppar_list_map.find(node_id) == this->toppar_list_map.end()) this->toppar_list_map[node_id] = (KafkaTopparList()); KafkaToppar new_toppar; if (!new_toppar.set_topic_partition(toppar->get_topic(), toppar->get_partition())) return -1; new_toppar.set_offset(toppar->get_offset()); new_toppar.set_low_watermark(toppar->get_low_watermark()); new_toppar.set_high_watermark(toppar->get_high_watermark()); this->toppar_list_map[node_id].add_item(std::move(new_toppar)); } } } else { this->toppar_list.rewind(); KafkaToppar *toppar; while ((toppar = this->toppar_list.get_next()) != NULL) { if (topic.compare(toppar->get_topic()) == 0) { int node_id = get_node_id(toppar); if (node_id < 0) return -1; if (this->toppar_list_map.find(node_id) == this->toppar_list_map.end()) this->toppar_list_map[node_id] = KafkaTopparList(); KafkaToppar new_toppar; if (!new_toppar.set_topic_partition(toppar->get_topic(), toppar->get_partition())) return -1; new_toppar.set_offset(toppar->get_offset()); new_toppar.set_offset_timestamp(toppar->get_offset_timestamp()); new_toppar.set_low_watermark(toppar->get_low_watermark()); new_toppar.set_high_watermark(toppar->get_high_watermark()); this->toppar_list_map[node_id].add_item(std::move(new_toppar)); } } } } return 0; } int KafkaClientTask::arrange_produce() { this->toppar_list.rewind(); KafkaToppar *toppar; while ((toppar = this->toppar_list.get_next()) != NULL) { if (toppar->get_partition() < 0) 
{ toppar->record_rewind(); KafkaRecord *record; while ((record = toppar->get_record_next()) != NULL) { int partition_num; const KafkaMeta *meta; meta = get_meta(toppar->get_topic(), &this->member->meta_list); if (!meta) return -1; partition_num = meta->get_partition_elements(); if (partition_num <= 0) return -1; int partition = -1; if (this->partitioner) { const void *key; size_t key_len; record->get_key(&key, &key_len); partition = this->partitioner(toppar->get_topic(), key, key_len, partition_num); } else partition = rand() % partition_num; KafkaToppar *new_toppar = get_toppar(toppar->get_topic(), partition, &this->toppar_list); if (!new_toppar) { KafkaToppar tmp; if (!tmp.set_topic_partition(toppar->get_topic(), partition)) return -1; new_toppar = this->toppar_list.add_item(std::move(tmp)); } record->get_raw_ptr()->toppar = new_toppar->get_raw_ptr(); new_toppar->add_record(std::move(*record)); toppar->del_record_cur(); delete record; } this->toppar_list.del_cur(); delete toppar; } else { KafkaRecord *record; while ((record = toppar->get_record_next()) != NULL) record->get_raw_ptr()->toppar = toppar->get_raw_ptr(); } } this->toppar_list.rewind(); KafkaTopparList toppar_list; while ((toppar = this->toppar_list.get_next()) != NULL) { int node_id = get_node_id(toppar); if (node_id < 0) return -1; if (this->toppar_list_map.find(node_id) == this->toppar_list_map.end()) this->toppar_list_map[node_id] = KafkaTopparList(); this->toppar_list_map[node_id].add_item(std::move(*toppar)); } return 0; } SubTask *WFKafkaTask::done() { SeriesWork *series = series_of(this); auto cb = [] (WFTimerTask *task) { WFKafkaTask *kafka_task = (WFKafkaTask *)task->user_data; if (kafka_task->callback) kafka_task->callback(kafka_task); delete kafka_task; }; if (finish) { if (this->state == WFT_STATE_TASK_ERROR) { WFTimerTask *timer; timer = WFTaskFactory::create_timer_task(0, 0, std::move(cb)); timer->user_data = this; series->push_front(timer); } else { if (this->callback) 
this->callback(this); delete this; } } return series->pop(); } int WFKafkaClient::init(const std::string& broker, SSL_CTX *ssl_ctx) { std::vector broker_hosts; std::string::size_type ppos = 0; std::string::size_type pos; bool use_ssl; use_ssl = (strncasecmp(broker.c_str(), "kafkas://", 9) == 0); while (1) { pos = broker.find(',', ppos); std::string host = broker.substr(ppos, pos - ppos); if (use_ssl) { if (strncasecmp(host.c_str(), "kafkas://", 9) != 0) { errno = EINVAL; return -1; } } else if (strncasecmp(host.c_str(), "kafka://", 8) != 0) { if (strncasecmp(host.c_str(), "kafkas://", 9) == 0) { errno = EINVAL; return -1; } host = "kafka://" + host; } broker_hosts.emplace_back(host); if (pos == std::string::npos) break; ppos = pos + 1; } this->member = new KafkaMember; this->member->broker_hosts = std::move(broker_hosts); this->member->ssl_ctx = ssl_ctx; if (use_ssl) { this->member->transport_type = TT_TCP_SSL; this->member->scheme = "kafkas://"; } return 0; } int WFKafkaClient::init(const std::string& broker, const std::string& group, SSL_CTX *ssl_ctx) { if (this->init(broker, ssl_ctx) < 0) return -1; this->member->cgroup.set_group(group); this->member->cgroup_status = KAFKA_CGROUP_UNINIT; return 0; } int WFKafkaClient::deinit() { this->member->mutex.lock(); this->member->client_deinit = true; this->member->mutex.unlock(); this->member->decref(); return 0; } WFKafkaTask *WFKafkaClient::create_kafka_task(const std::string& query, int retry_max, kafka_callback_t cb) { WFKafkaTask *task = new KafkaClientTask(query, retry_max, std::move(cb), this); return task; } WFKafkaTask *WFKafkaClient::create_kafka_task(int retry_max, kafka_callback_t cb) { WFKafkaTask *task = new KafkaClientTask("", retry_max, std::move(cb), this); return task; } WFKafkaTask *WFKafkaClient::create_leavegroup_task(int retry_max, kafka_callback_t cb) { WFKafkaTask *task = new KafkaClientTask("api=leavegroup", retry_max, std::move(cb), this); return task; } void 
WFKafkaClient::set_config(protocol::KafkaConfig conf) { this->member->config = std::move(conf); } KafkaMetaList *WFKafkaClient::get_meta_list() { return &this->member->meta_list; } workflow-0.11.8/src/client/WFKafkaClient.h000066400000000000000000000112261476003635400203140ustar00rootroot00000000000000/* Copyright (c) 2020 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Authors: Wang Zhulei (wangzhulei@sogou-inc.com) */ #ifndef _WFKAFKACLIENT_H_ #define _WFKAFKACLIENT_H_ #include #include #include #include #include "WFTask.h" #include "KafkaMessage.h" #include "KafkaResult.h" class WFKafkaTask; class WFKafkaClient; using kafka_callback_t = std::function; using kafka_partitioner_t = std::function; class WFKafkaTask : public WFGenericTask { public: virtual bool add_topic(const std::string& topic) = 0; virtual bool add_toppar(const protocol::KafkaToppar& toppar) = 0; virtual bool add_produce_record(const std::string& topic, int partition, protocol::KafkaRecord record) = 0; virtual bool add_offset_toppar(const protocol::KafkaToppar& toppar) = 0; void add_commit_record(const protocol::KafkaRecord& record) { protocol::KafkaToppar toppar; toppar.set_topic_partition(record.get_topic(), record.get_partition()); toppar.set_offset(record.get_offset()); toppar.set_error(0); this->toppar_list.add_item(std::move(toppar)); } void add_commit_toppar(const protocol::KafkaToppar& toppar) { protocol::KafkaToppar toppar_t; toppar_t.set_topic_partition(toppar.get_topic(), toppar.get_partition()); 
toppar_t.set_offset(toppar.get_offset()); toppar_t.set_error(0); this->toppar_list.add_item(std::move(toppar_t)); } void add_commit_item(const std::string& topic, int partition, long long offset) { protocol::KafkaToppar toppar; toppar.set_topic_partition(topic, partition); toppar.set_offset(offset); toppar.set_error(0); this->toppar_list.add_item(std::move(toppar)); } void set_api_type(int api_type) { this->api_type = api_type; } int get_api_type() const { return this->api_type; } void set_config(protocol::KafkaConfig conf) { this->config = std::move(conf); } void set_partitioner(kafka_partitioner_t partitioner) { this->partitioner = std::move(partitioner); } protocol::KafkaResult *get_result() { return &this->result; } int get_kafka_error() const { return this->kafka_error; } void set_callback(kafka_callback_t cb) { this->callback = std::move(cb); } protected: WFKafkaTask(int retry_max, kafka_callback_t&& cb) { this->callback = std::move(cb); this->retry_max = retry_max; this->finish = false; } virtual ~WFKafkaTask() { } virtual SubTask *done(); protected: protocol::KafkaConfig config; protocol::KafkaTopparList toppar_list; protocol::KafkaMetaList meta_list; protocol::KafkaResult result; kafka_callback_t callback; kafka_partitioner_t partitioner; int api_type; int kafka_error; int retry_max; bool finish; private: friend class WFKafkaClient; }; class WFKafkaClient { public: // example: kafka://10.160.23.23:9000 // example: kafka://kafka.sogou // example: kafka.sogou:9090 // example: kafka://10.160.23.23:9000,10.123.23.23,kafka://kafka.sogou // example: kafkas://kafka.sogou -> kafka over TLS int init(const std::string& broker_url) { return this->init(broker_url, NULL); } int init(const std::string& broker_url, const std::string& group) { return this->init(broker_url, group, NULL); } // With a specific SSL_CTX. Effective only on brokers over TLS. 
int init(const std::string& broker_url, SSL_CTX *ssl_ctx); int init(const std::string& broker_url, const std::string& group, SSL_CTX *ssl_ctx); int deinit(); // example: topic=xxx&topic=yyy&api=fetch // example: api=commit WFKafkaTask *create_kafka_task(const std::string& query, int retry_max, kafka_callback_t cb); WFKafkaTask *create_kafka_task(int retry_max, kafka_callback_t cb); void set_config(protocol::KafkaConfig conf); public: /* If you don't leavegroup manually, rebalance would be triggered */ WFKafkaTask *create_leavegroup_task(int retry_max, kafka_callback_t callback); public: protocol::KafkaMetaList *get_meta_list(); protocol::KafkaBrokerList *get_broker_list(); private: class KafkaMember *member; friend class KafkaClientTask; }; #endif workflow-0.11.8/src/client/WFMySQLConnection.cc000066400000000000000000000024161476003635400212640ustar00rootroot00000000000000/* Copyright (c) 2020 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Author: Xie Han (xiehan@sogou-inc.com) */ #include #include #include #include #include #include "URIParser.h" #include "WFMySQLConnection.h" int WFMySQLConnection::init(const std::string& url, SSL_CTX *ssl_ctx) { std::string query; ParsedURI uri; if (URIParser::parse(url, uri) >= 0) { if (uri.query) { query = uri.query; query += '&'; } query += "transaction=INTERNAL_CONN_ID_" + std::to_string(this->id); free(uri.query); uri.query = strdup(query.c_str()); if (uri.query) { this->uri = std::move(uri); this->ssl_ctx = ssl_ctx; return 0; } } else if (uri.state == URI_STATE_INVALID) errno = EINVAL; return -1; } workflow-0.11.8/src/client/WFMySQLConnection.h000066400000000000000000000047171476003635400211340ustar00rootroot00000000000000/* Copyright (c) 2020 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Xie Han (xiehan@sogou-inc.com) */ #ifndef _WFMYSQLCONNECTION_H_ #define _WFMYSQLCONNECTION_H_ #include #include #include #include #include "URIParser.h" #include "WFTaskFactory.h" class WFMySQLConnection { public: /* example: mysql://username:passwd@127.0.0.1/dbname?character_set=utf8 * IP string is recommmended in url. When using a domain name, the first * address resovled will be used. Don't use upstream name as a host. 
*/ int init(const std::string& url) { return this->init(url, NULL); } int init(const std::string& url, SSL_CTX *ssl_ctx); void deinit() { } public: WFMySQLTask *create_query_task(const std::string& query, mysql_callback_t callback) { WFMySQLTask *task = WFTaskFactory::create_mysql_task(this->uri, 0, std::move(callback)); this->set_ssl_ctx(task); task->get_req()->set_query(query); return task; } /* If you don't disconnect manually, the TCP connection will be * kept alive after this object is deleted, and maybe reused by * another WFMySQLConnection object with same id and url. */ WFMySQLTask *create_disconnect_task(mysql_callback_t callback) { WFMySQLTask *task = this->create_query_task("", std::move(callback)); this->set_ssl_ctx(task); task->set_keep_alive(0); return task; } protected: void set_ssl_ctx(WFMySQLTask *task) const { using MySQLRequest = protocol::MySQLRequest; using MySQLResponse = protocol::MySQLResponse; auto *t = (WFComplexClientTask *)task; /* 'ssl_ctx' can be NULL and will use default. */ t->set_ssl_ctx(this->ssl_ctx); } protected: ParsedURI uri; SSL_CTX *ssl_ctx; int id; public: /* Make sure that concurrent connections have different id. * When a connection object is deleted, id can be reused. */ WFMySQLConnection(int id) { this->id = id; } virtual ~WFMySQLConnection() { } }; #endif workflow-0.11.8/src/client/WFRedisSubscriber.cc000066400000000000000000000061171476003635400213730ustar00rootroot00000000000000/* Copyright (c) 2024 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. Author: Xie Han (xiehan@sogou-inc.com) */ #include #include #include #include #include "URIParser.h" #include "RedisTaskImpl.inl" #include "WFRedisSubscriber.h" int WFRedisSubscribeTask::sync_send(const std::string& command, const std::vector& params) { std::string str("*" + std::to_string(1 + params.size()) + "\r\n"); int ret; str += "$" + std::to_string(command.size()) + "\r\n" + command + "\r\n"; for (const std::string& p : params) str += "$" + std::to_string(p.size()) + "\r\n" + p + "\r\n"; this->mutex.lock(); if (this->task) { ret = this->task->push(str.c_str(), str.size()); if (ret == (int)str.size()) ret = 0; else { if (ret >= 0) errno = ENOBUFS; ret = -1; } } else { errno = ENOENT; ret = -1; } this->mutex.unlock(); return ret; } void WFRedisSubscribeTask::task_extract(WFRedisTask *task) { auto *t = (WFRedisSubscribeTask *)task->user_data; if (t->extract) t->extract(t); } void WFRedisSubscribeTask::task_callback(WFRedisTask *task) { auto *t = (WFRedisSubscribeTask *)task->user_data; t->mutex.lock(); t->task = NULL; t->mutex.unlock(); t->state = task->get_state(); t->error = task->get_error(); if (t->callback) t->callback(t); t->release(); } int WFRedisSubscriber::init(const std::string& url, SSL_CTX *ssl_ctx) { if (URIParser::parse(url, this->uri) >= 0) { this->ssl_ctx = ssl_ctx; return 0; } if (this->uri.state == URI_STATE_INVALID) errno = EINVAL; return -1; } WFRedisTask * WFRedisSubscriber::create_redis_task(const std::string& command, const std::vector& params) { WFRedisTask *task = __WFRedisTaskFactory::create_subscribe_task(this->uri, WFRedisSubscribeTask::task_extract, WFRedisSubscribeTask::task_callback); this->set_ssl_ctx(task); task->get_req()->set_request(command, params); return task; } WFRedisSubscribeTask * WFRedisSubscriber::create_subscribe_task( const std::vector& channels, extract_t extract, callback_t callback) { WFRedisTask *task = 
this->create_redis_task("SUBSCRIBE", channels); return new WFRedisSubscribeTask(task, std::move(extract), std::move(callback)); } WFRedisSubscribeTask * WFRedisSubscriber::create_psubscribe_task( const std::vector& patterns, extract_t extract, callback_t callback) { WFRedisTask *task = this->create_redis_task("PSUBSCRIBE", patterns); return new WFRedisSubscribeTask(task, std::move(extract), std::move(callback)); } workflow-0.11.8/src/client/WFRedisSubscriber.h000066400000000000000000000127611476003635400212370ustar00rootroot00000000000000/* Copyright (c) 2024 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Xie Han (xiehan@sogou-inc.com) */ #ifndef _WFREDISSUBSCRIBER_H_ #define _WFREDISSUBSCRIBER_H_ #include #include #include #include #include #include #include #include #include "WFTask.h" #include "WFTaskFactory.h" class WFRedisSubscribeTask : public WFGenericTask { public: /* Note: Call 'get_resp()' only in the 'extract' function or before the task is started to set response size limit. */ protocol::RedisResponse *get_resp() { return this->task->get_resp(); } public: /* User needs to call 'release()' exactly once, anywhere. */ void release() { if (this->flag.exchange(true)) delete this; } public: /* Note: After 'release()' is called, all the requesting functions should not be called except in 'extract', because the task point may have been deleted because 'callback' finished. 
*/ int subscribe(const std::vector& channels) { return this->sync_send("SUBSCRIBE", channels); } int unsubscribe(const std::vector& channels) { return this->sync_send("UNSUBSCRIBE", channels); } int unsubscribe() { return this->sync_send("UNSUBSCRIBE", { }); } int psubscribe(const std::vector& patterns) { return this->sync_send("PSUBSCRIBE", patterns); } int punsubscribe(const std::vector& patterns) { return this->sync_send("PUNSUBSCRIBE", patterns); } int punsubscribe() { return this->sync_send("PUNSUBSCRIBE", { }); } int ping(const std::string& message) { return this->sync_send("PING", { message }); } int ping() { return this->sync_send("PING", { }); } int quit() { return this->sync_send("QUIT", { }); } public: /* All 'timeout' proxy functions can only be called only before the task is started or in 'extract'. */ /* Timeout of waiting for each message. Very useful. If not set, the max waiting time will be the global 'response_timeout'*/ void set_watch_timeout(int timeout) { this->task->set_watch_timeout(timeout); } /* Timeout of receiving a complete message. */ void set_recv_timeout(int timeout) { this->task->set_receive_timeout(timeout); } /* Timeout of sending the first subscribe request. */ void set_send_timeout(int timeout) { this->task->set_send_timeout(timeout); } /* The default keep alive timeout is 0. If you want to keep the connection alive, make sure not to send any request after all channels/patterns were unsubscribed. */ void set_keep_alive(int timeout) { this->task->set_keep_alive(timeout); } public: /* Call 'set_extract' or 'set_callback' only before the task is started, or in 'extract'. 
*/ void set_extract(std::function ex) { this->extract = std::move(ex); } void set_callback(std::function cb) { this->callback = std::move(cb); } protected: virtual void dispatch() { series_of(this)->push_front(this->task); this->subtask_done(); } virtual SubTask *done() { return series_of(this)->pop(); } protected: int sync_send(const std::string& command, const std::vector& params); static void task_extract(WFRedisTask *task); static void task_callback(WFRedisTask *task); protected: WFRedisTask *task; std::mutex mutex; std::atomic flag; std::function extract; std::function callback; protected: WFRedisSubscribeTask(WFRedisTask *task, std::function&& ex, std::function&& cb) : flag(false), extract(std::move(ex)), callback(std::move(cb)) { task->user_data = this; this->task = task; } virtual ~WFRedisSubscribeTask() { if (this->task) this->task->dismiss(); } friend class WFRedisSubscriber; }; class WFRedisSubscriber { public: int init(const std::string& url) { return this->init(url, NULL); } int init(const std::string& url, SSL_CTX *ssl_ctx); void deinit() { } public: using extract_t = std::function; using callback_t = std::function; public: WFRedisSubscribeTask * create_subscribe_task(const std::vector& channels, extract_t extract, callback_t callback); WFRedisSubscribeTask * create_psubscribe_task(const std::vector& patterns, extract_t extract, callback_t callback); protected: void set_ssl_ctx(WFRedisTask *task) const { using RedisRequest = protocol::RedisRequest; using RedisResponse = protocol::RedisResponse; auto *t = (WFComplexClientTask *)task; /* 'ssl_ctx' can be NULL and will use default. 
*/ t->set_ssl_ctx(this->ssl_ctx); } protected: WFRedisTask *create_redis_task(const std::string& command, const std::vector& params); protected: ParsedURI uri; SSL_CTX *ssl_ctx; public: virtual ~WFRedisSubscriber() { } }; #endif workflow-0.11.8/src/client/xmake.lua000066400000000000000000000011411476003635400173350ustar00rootroot00000000000000target("client") set_kind("object") add_files("*.cc") remove_files("WFKafkaClient.cc") if not has_config("redis") then remove_files("WFRedisSubscriber.cc") end if not has_config("mysql") then remove_files("WFMySQLConnection.cc") end if not has_config("consul") then remove_files("WFConsulClient.cc") end target("kafka_client") if has_config("kafka") then add_files("WFKafkaClient.cc") set_kind("object") add_deps("client") add_packages("zlib", "snappy", "zstd", "lz4") else set_kind("phony") end workflow-0.11.8/src/factory/000077500000000000000000000000001476003635400157215ustar00rootroot00000000000000workflow-0.11.8/src/factory/CMakeLists.txt000066400000000000000000000011171476003635400204610ustar00rootroot00000000000000cmake_minimum_required(VERSION 3.6) project(factory) set(SRC WFGraphTask.cc DnsTaskImpl.cc WFTaskFactory.cc Workflow.cc HttpTaskImpl.cc WFResourcePool.cc WFMessageQueue.cc FileTaskImpl.cc ) if (NOT MYSQL STREQUAL "n") set(SRC ${SRC} MySQLTaskImpl.cc ) endif () if (NOT REDIS STREQUAL "n") set(SRC ${SRC} RedisTaskImpl.cc ) endif () add_library(${PROJECT_NAME} OBJECT ${SRC}) if (KAFKA STREQUAL "y") set(SRC KafkaTaskImpl.cc ) add_library("factory_kafka" OBJECT ${SRC}) set_property(SOURCE KafkaTaskImpl.cc APPEND PROPERTY COMPILE_OPTIONS "-fno-rtti") endif () workflow-0.11.8/src/factory/DnsTaskImpl.cc000066400000000000000000000135721476003635400204310ustar00rootroot00000000000000/* Copyright (c) 2021 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Liu Kai (liukaidx@sogou-inc.com) */ #include #include #include "DnsMessage.h" #include "WFTaskError.h" #include "WFTaskFactory.h" #include "WFServer.h" using namespace protocol; #define DNS_KEEPALIVE_DEFAULT (60 * 1000) /**********Client**********/ class ComplexDnsTask : public WFComplexClientTask> { static struct addrinfo hints; static std::atomic seq; public: ComplexDnsTask(int retry_max, dns_callback_t&& cb): WFComplexClientTask(retry_max, std::move(cb)) { this->set_transport_type(TT_UDP); } protected: virtual CommMessageOut *message_out(); virtual bool init_success(); virtual bool finish_once(); private: bool need_redirect(); }; struct addrinfo ComplexDnsTask::hints = { .ai_flags = AI_NUMERICSERV | AI_NUMERICHOST, .ai_family = AF_UNSPEC, .ai_socktype = SOCK_STREAM }; std::atomic ComplexDnsTask::seq(0); CommMessageOut *ComplexDnsTask::message_out() { DnsRequest *req = this->get_req(); DnsResponse *resp = this->get_resp(); enum TransportType type = this->get_transport_type(); if (req->get_id() == 0) req->set_id(++ComplexDnsTask::seq * 99991 % 65535 + 1); resp->set_request_id(req->get_id()); resp->set_request_name(req->get_question_name()); req->set_single_packet(type == TT_UDP); resp->set_single_packet(type == TT_UDP); return this->WFClientTask::message_out(); } bool ComplexDnsTask::init_success() { if (uri_.scheme && strcasecmp(uri_.scheme, "dnss") == 0) this->WFComplexClientTask::set_transport_type(TT_TCP_SSL); else if (!uri_.scheme || strcasecmp(uri_.scheme, "dns") != 0) { this->state = WFT_STATE_TASK_ERROR; this->error = WFT_ERR_URI_SCHEME_INVALID; return false; } if 
(!this->route_result_.request_object) { enum TransportType type = this->get_transport_type(); struct addrinfo *addr; int ret; ret = getaddrinfo(uri_.host, uri_.port, &hints, &addr); if (ret != 0) { this->state = WFT_STATE_DNS_ERROR; this->error = ret; return false; } auto *ep = &WFGlobal::get_global_settings()->dns_server_params; ret = WFGlobal::get_route_manager()->get(type, addr, info_, ep, uri_.host, ssl_ctx_, route_result_); freeaddrinfo(addr); if (ret < 0) { this->state = WFT_STATE_SYS_ERROR; this->error = errno; return false; } } return true; } bool ComplexDnsTask::finish_once() { if (this->state == WFT_STATE_SUCCESS) { if (need_redirect()) this->set_redirect(uri_); else if (this->state != WFT_STATE_SUCCESS) this->disable_retry(); } /* If retry times meet retry max and there is no redirect, * we ask the client for a retry or redirect. */ if (retry_times_ == retry_max_ && !redirect_ && *this->get_mutable_ctx()) { /* Reset type to UDP before a client redirect. */ this->set_transport_type(TT_UDP); (*this->get_mutable_ctx())(this); } return true; } bool ComplexDnsTask::need_redirect() { DnsResponse *client_resp = this->get_resp(); enum TransportType type = this->get_transport_type(); if (type == TT_UDP && client_resp->get_tc() == 1) { this->set_transport_type(TT_TCP); return true; } return false; } /**********Client Factory**********/ WFDnsTask *WFTaskFactory::create_dns_task(const std::string& url, int retry_max, dns_callback_t callback) { ParsedURI uri; URIParser::parse(url, uri); return WFTaskFactory::create_dns_task(uri, retry_max, std::move(callback)); } WFDnsTask *WFTaskFactory::create_dns_task(const ParsedURI& uri, int retry_max, dns_callback_t callback) { ComplexDnsTask *task = new ComplexDnsTask(retry_max, std::move(callback)); const char *name; if (uri.path && uri.path[0] && uri.path[1]) name = uri.path + 1; else name = "."; DnsRequest *req = task->get_req(); req->set_question(name, DNS_TYPE_A, DNS_CLASS_IN); task->init(uri); 
task->set_keep_alive(DNS_KEEPALIVE_DEFAULT); return task; } /**********Server**********/ class WFDnsServerTask : public WFServerTask { public: WFDnsServerTask(CommService *service, std::function& proc) : WFServerTask(service, WFGlobal::get_scheduler(), proc) { this->type = ((WFServerBase *)service)->get_params()->transport_type; } protected: virtual CommMessageIn *message_in() { this->get_req()->set_single_packet(this->type == TT_UDP); return this->WFServerTask::message_in(); } virtual CommMessageOut *message_out() { this->get_resp()->set_single_packet(this->type == TT_UDP); return this->WFServerTask::message_out(); } virtual void handle(int state, int error); protected: enum TransportType type; }; void WFDnsServerTask::handle(int state, int error) { if (state == WFT_STATE_TOREPLY) { DnsRequest *req = this->get_req(); DnsResponse *resp = this->get_resp(); resp->set_question_name(req->get_question_name()); resp->set_question_type(req->get_question_type()); resp->set_question_class(req->get_question_class()); resp->set_opcode(req->get_opcode()); resp->set_id(req->get_id()); resp->set_rd(req->get_rd()); resp->set_qr(1); } return WFServerTask::handle(state, error); } /**********Server Factory**********/ WFDnsTask *WFServerTaskFactory::create_dns_task(CommService *service, std::function& proc) { return new WFDnsServerTask(service, proc); } workflow-0.11.8/src/factory/FileTaskImpl.cc000066400000000000000000000215771476003635400205700ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. Authors: Xie Han (xiehan@sogou-inc.com) Li Jinghao (lijinghao@sogou-inc.com) */ #include #include #include #include "WFGlobal.h" #include "WFTaskFactory.h" class WFFilepreadTask : public WFFileIOTask { public: WFFilepreadTask(int fd, void *buf, size_t count, off_t offset, IOService *service, fio_callback_t&& cb) : WFFileIOTask(service, std::move(cb)) { this->args.fd = fd; this->args.buf = buf; this->args.count = count; this->args.offset = offset; } protected: virtual int prepare() { this->prep_pread(this->args.fd, this->args.buf, this->args.count, this->args.offset); return 0; } }; class WFFilepwriteTask : public WFFileIOTask { public: WFFilepwriteTask(int fd, const void *buf, size_t count, off_t offset, IOService *service, fio_callback_t&& cb) : WFFileIOTask(service, std::move(cb)) { this->args.fd = fd; this->args.buf = (void *)buf; this->args.count = count; this->args.offset = offset; } protected: virtual int prepare() { this->prep_pwrite(this->args.fd, this->args.buf, this->args.count, this->args.offset); return 0; } }; class WFFilepreadvTask : public WFFileVIOTask { public: WFFilepreadvTask(int fd, const struct iovec *iov, int iovcnt, off_t offset, IOService *service, fvio_callback_t&& cb) : WFFileVIOTask(service, std::move(cb)) { this->args.fd = fd; this->args.iov = iov; this->args.iovcnt = iovcnt; this->args.offset = offset; } protected: virtual int prepare() { this->prep_preadv(this->args.fd, this->args.iov, this->args.iovcnt, this->args.offset); return 0; } }; class WFFilepwritevTask : public WFFileVIOTask { public: WFFilepwritevTask(int fd, const struct iovec *iov, int iovcnt, off_t offset, IOService *service, fvio_callback_t&& cb) : WFFileVIOTask(service, std::move(cb)) { this->args.fd = fd; this->args.iov = iov; this->args.iovcnt = iovcnt; this->args.offset = offset; } protected: virtual int prepare() { this->prep_pwritev(this->args.fd, this->args.iov, 
this->args.iovcnt, this->args.offset); return 0; } }; class WFFilefsyncTask : public WFFileSyncTask { public: WFFilefsyncTask(int fd, IOService *service, fsync_callback_t&& cb) : WFFileSyncTask(service, std::move(cb)) { this->args.fd = fd; } protected: virtual int prepare() { this->prep_fsync(this->args.fd); return 0; } }; class WFFilefdsyncTask : public WFFileSyncTask { public: WFFilefdsyncTask(int fd, IOService *service, fsync_callback_t&& cb) : WFFileSyncTask(service, std::move(cb)) { this->args.fd = fd; } protected: virtual int prepare() { this->prep_fdsync(this->args.fd); return 0; } }; /* File tasks created with path name. */ class __WFFilepreadTask : public WFFilepreadTask { public: __WFFilepreadTask(const std::string& path, void *buf, size_t count, off_t offset, IOService *service, fio_callback_t&& cb): WFFilepreadTask(-1, buf, count, offset, service, std::move(cb)), pathname(path) { } protected: virtual int prepare() { this->args.fd = open(this->pathname.c_str(), O_RDONLY); if (this->args.fd < 0) return -1; return WFFilepreadTask::prepare(); } virtual SubTask *done() { if (this->args.fd >= 0) { close(this->args.fd); this->args.fd = -1; } return WFFilepreadTask::done(); } protected: std::string pathname; }; class __WFFilepwriteTask : public WFFilepwriteTask { public: __WFFilepwriteTask(const std::string& path, const void *buf, size_t count, off_t offset, IOService *service, fio_callback_t&& cb): WFFilepwriteTask(-1, buf, count, offset, service, std::move(cb)), pathname(path) { } protected: virtual int prepare() { this->args.fd = open(this->pathname.c_str(), O_WRONLY | O_CREAT, 0644); if (this->args.fd < 0) return -1; return WFFilepwriteTask::prepare(); } virtual SubTask *done() { if (this->args.fd >= 0) { close(this->args.fd); this->args.fd = -1; } return WFFilepwriteTask::done(); } protected: std::string pathname; }; class __WFFilepreadvTask : public WFFilepreadvTask { public: __WFFilepreadvTask(const std::string& path, const struct iovec *iov, int iovcnt, 
off_t offset, IOService *service, fvio_callback_t&& cb) : WFFilepreadvTask(-1, iov, iovcnt, offset, service, std::move(cb)), pathname(path) { } protected: virtual int prepare() { this->args.fd = open(this->pathname.c_str(), O_RDONLY); if (this->args.fd < 0) return -1; return WFFilepreadvTask::prepare(); } virtual SubTask *done() { if (this->args.fd >= 0) { close(this->args.fd); this->args.fd = -1; } return WFFilepreadvTask::done(); } protected: std::string pathname; }; class __WFFilepwritevTask : public WFFilepwritevTask { public: __WFFilepwritevTask(const std::string& path, const struct iovec *iov, int iovcnt, off_t offset, IOService *service, fvio_callback_t&& cb) : WFFilepwritevTask(-1, iov, iovcnt, offset, service, std::move(cb)), pathname(path) { } protected: virtual int prepare() { this->args.fd = open(this->pathname.c_str(), O_WRONLY | O_CREAT, 0644); if (this->args.fd < 0) return -1; return WFFilepwritevTask::prepare(); } protected: virtual SubTask *done() { if (this->args.fd >= 0) { close(this->args.fd); this->args.fd = -1; } return WFFilepwritevTask::done(); } protected: std::string pathname; }; /* Factory functions with fd. 
*/ WFFileIOTask *WFTaskFactory::create_pread_task(int fd, void *buf, size_t count, off_t offset, fio_callback_t callback) { return new WFFilepreadTask(fd, buf, count, offset, WFGlobal::get_io_service(), std::move(callback)); } WFFileIOTask *WFTaskFactory::create_pwrite_task(int fd, const void *buf, size_t count, off_t offset, fio_callback_t callback) { return new WFFilepwriteTask(fd, buf, count, offset, WFGlobal::get_io_service(), std::move(callback)); } WFFileVIOTask *WFTaskFactory::create_preadv_task(int fd, const struct iovec *iovec, int iovcnt, off_t offset, fvio_callback_t callback) { return new WFFilepreadvTask(fd, iovec, iovcnt, offset, WFGlobal::get_io_service(), std::move(callback)); } WFFileVIOTask *WFTaskFactory::create_pwritev_task(int fd, const struct iovec *iovec, int iovcnt, off_t offset, fvio_callback_t callback) { return new WFFilepwritevTask(fd, iovec, iovcnt, offset, WFGlobal::get_io_service(), std::move(callback)); } WFFileSyncTask *WFTaskFactory::create_fsync_task(int fd, fsync_callback_t callback) { return new WFFilefsyncTask(fd, WFGlobal::get_io_service(), std::move(callback)); } WFFileSyncTask *WFTaskFactory::create_fdsync_task(int fd, fsync_callback_t callback) { return new WFFilefdsyncTask(fd, WFGlobal::get_io_service(), std::move(callback)); } /* Factory functions with path name. 
*/ WFFileIOTask *WFTaskFactory::create_pread_task(const std::string& path, void *buf, size_t count, off_t offset, fio_callback_t callback) { return new __WFFilepreadTask(path, buf, count, offset, WFGlobal::get_io_service(), std::move(callback)); } WFFileIOTask *WFTaskFactory::create_pwrite_task(const std::string& path, const void *buf, size_t count, off_t offset, fio_callback_t callback) { return new __WFFilepwriteTask(path, buf, count, offset, WFGlobal::get_io_service(), std::move(callback)); } WFFileVIOTask *WFTaskFactory::create_preadv_task(const std::string& path, const struct iovec *iovec, int iovcnt, off_t offset, fvio_callback_t callback) { return new __WFFilepreadvTask(path, iovec, iovcnt, offset, WFGlobal::get_io_service(), std::move(callback)); } WFFileVIOTask *WFTaskFactory::create_pwritev_task(const std::string& path, const struct iovec *iovec, int iovcnt, off_t offset, fvio_callback_t callback) { return new __WFFilepwritevTask(path, iovec, iovcnt, offset, WFGlobal::get_io_service(), std::move(callback)); } workflow-0.11.8/src/factory/HttpTaskImpl.cc000066400000000000000000000570371476003635400206300ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) Liu Kai (liukaidx@sogou-inc.com) Xie Han (xiehan@sogou-inc.com) Li Yingxin (liyingxin@sogou-inc.com) */ #include #include #include #include #include #include #include #include "WFTaskError.h" #include "WFTaskFactory.h" #include "StringUtil.h" #include "WFGlobal.h" #include "HttpUtil.h" #include "SSLWrapper.h" using namespace protocol; #define HTTP_KEEPALIVE_DEFAULT (60 * 1000) #define HTTP_KEEPALIVE_MAX (300 * 1000) /**********Client**********/ static int __encode_auth(const char *p, std::string& auth) { size_t len = strlen(p); size_t base64_len = (len + 2) / 3 * 4; char *base64 = (char *)malloc(base64_len + 1); if (!base64) return -1; EVP_EncodeBlock((unsigned char *)base64, (const unsigned char *)p, len); auth.append("Basic "); auth.append(base64, base64_len); free(base64); return 0; } class ComplexHttpTask : public WFComplexClientTask { public: ComplexHttpTask(int redirect_max, int retry_max, http_callback_t&& callback): WFComplexClientTask(retry_max, std::move(callback)), redirect_max_(redirect_max), redirect_count_(0) { HttpRequest *client_req = this->get_req(); client_req->set_method(HttpMethodGet); client_req->set_http_version("HTTP/1.1"); } protected: virtual CommMessageOut *message_out(); virtual CommMessageIn *message_in(); virtual int keep_alive_timeout(); virtual bool init_success(); virtual void init_failed(); virtual bool finish_once(); protected: bool need_redirect(const ParsedURI& uri, ParsedURI& new_uri); bool redirect_url(HttpResponse *client_resp, const ParsedURI& uri, ParsedURI& new_uri); void set_empty_request(); void check_response(); private: int redirect_max_; int redirect_count_; }; CommMessageOut *ComplexHttpTask::message_out() { HttpRequest *req = this->get_req(); struct HttpMessageHeader header; bool is_alive; if (!req->is_chunked() && !req->has_content_length_header()) { size_t body_size = req->get_output_body_size(); const char *method = req->get_method(); if (body_size != 0 || strcmp(method, 
"POST") == 0 || strcmp(method, "PUT") == 0) { char buf[32]; header.name = "Content-Length"; header.name_len = strlen("Content-Length"); header.value = buf; header.value_len = sprintf(buf, "%zu", body_size); req->add_header(&header); } } if (req->has_connection_header()) is_alive = req->is_keep_alive(); else { header.name = "Connection"; header.name_len = strlen("Connection"); is_alive = (this->keep_alive_timeo != 0); if (is_alive) { header.value = "Keep-Alive"; header.value_len = strlen("Keep-Alive"); } else { header.value = "close"; header.value_len = strlen("close"); } req->add_header(&header); } if (!is_alive) this->keep_alive_timeo = 0; else if (req->has_keep_alive_header()) { HttpHeaderCursor cursor(req); //req---Connection: Keep-Alive //req---Keep-Alive: timeout=0,max=100 header.name = "Keep-Alive"; header.name_len = strlen("Keep-Alive"); header.value = NULL; header.value_len = 0; if (cursor.find(&header)) { std::string keep_alive((const char *)header.value, header.value_len); std::vector params = StringUtil::split(keep_alive, ','); for (const auto& kv : params) { std::vector arr = StringUtil::split(kv, '='); if (arr.size() < 2) arr.emplace_back("0"); std::string key = StringUtil::strip(arr[0]); std::string val = StringUtil::strip(arr[1]); if (strcasecmp(key.c_str(), "timeout") == 0) { this->keep_alive_timeo = 1000 * atoi(val.c_str()); break; } } } if ((unsigned int)this->keep_alive_timeo > HTTP_KEEPALIVE_MAX) this->keep_alive_timeo = HTTP_KEEPALIVE_MAX; } return this->WFComplexClientTask::message_out(); } CommMessageIn *ComplexHttpTask::message_in() { HttpResponse *resp = this->get_resp(); if (strcmp(this->get_req()->get_method(), HttpMethodHead) == 0) resp->parse_zero_body(); return this->WFComplexClientTask::message_in(); } int ComplexHttpTask::keep_alive_timeout() { return this->resp.is_keep_alive() ? 
this->keep_alive_timeo : 0; } void ComplexHttpTask::set_empty_request() { HttpRequest *client_req = this->get_req(); HttpHeaderCursor cursor(client_req); struct HttpMessageHeader header = { .name = "Host", .name_len = strlen("Host"), }; client_req->set_request_uri("/"); cursor.find_and_erase(&header); header.name = "Authorization"; header.name_len = strlen("Authorization"); cursor.find_and_erase(&header); } void ComplexHttpTask::init_failed() { this->set_empty_request(); } bool ComplexHttpTask::init_success() { HttpRequest *client_req = this->get_req(); std::string request_uri; std::string header_host; bool is_ssl; if (uri_.scheme && strcasecmp(uri_.scheme, "http") == 0) is_ssl = false; else if (uri_.scheme && strcasecmp(uri_.scheme, "https") == 0) is_ssl = true; else { this->state = WFT_STATE_TASK_ERROR; this->error = WFT_ERR_URI_SCHEME_INVALID; return false; } //todo http+unix //https://stackoverflow.com/questions/26964595/whats-the-correct-way-to-use-a-unix-domain-socket-in-requests-framework //https://stackoverflow.com/questions/27037990/connecting-to-postgres-via-database-url-and-unix-socket-in-rails if (uri_.path && uri_.path[0]) request_uri = uri_.path; else request_uri = "/"; if (uri_.query && uri_.query[0]) { request_uri += "?"; request_uri += uri_.query; } if (uri_.host && uri_.host[0]) header_host = uri_.host; if (uri_.port && uri_.port[0]) { int port = atoi(uri_.port); if (is_ssl) { if (port != 443) { header_host += ":"; header_host += uri_.port; } } else { if (port != 80) { header_host += ":"; header_host += uri_.port; } } } this->WFComplexClientTask::set_transport_type(is_ssl ? 
TT_TCP_SSL : TT_TCP); client_req->set_request_uri(request_uri.c_str()); client_req->set_header_pair("Host", header_host.c_str()); if (uri_.userinfo && uri_.userinfo[0]) { std::string userinfo(uri_.userinfo); std::string http_auth; StringUtil::url_decode(userinfo); if (__encode_auth(userinfo.c_str(), http_auth) < 0) { this->state = WFT_STATE_SYS_ERROR; this->error = errno; return false; } client_req->set_header_pair("Authorization", http_auth.c_str()); } return true; } bool ComplexHttpTask::redirect_url(HttpResponse *client_resp, const ParsedURI& uri, ParsedURI& new_uri) { if (redirect_count_ < redirect_max_) { redirect_count_++; std::string url; HttpHeaderCursor cursor(client_resp); if (!cursor.find("Location", url) || url.empty()) { this->state = WFT_STATE_TASK_ERROR; this->error = WFT_ERR_HTTP_BAD_REDIRECT_HEADER; return false; } if (url[0] == '/') { if (url[1] != '/') { if (uri.port) url = ':' + (uri.port + url); url = "//" + (uri.host + url); } url = uri.scheme + (':' + url); } URIParser::parse(url, new_uri); return true; } return false; } bool ComplexHttpTask::need_redirect(const ParsedURI& uri, ParsedURI& new_uri) { HttpRequest *client_req = this->get_req(); HttpResponse *client_resp = this->get_resp(); const char *status_code_str = client_resp->get_status_code(); const char *method = client_req->get_method(); if (!status_code_str || !method) return false; int status_code = atoi(status_code_str); switch (status_code) { case 301: case 302: case 303: if (redirect_url(client_resp, uri, new_uri)) { if (strcasecmp(method, HttpMethodGet) != 0 && strcasecmp(method, HttpMethodHead) != 0) { client_req->set_method(HttpMethodGet); } return true; } else break; case 307: case 308: if (redirect_url(client_resp, uri, new_uri)) return true; else break; default: break; } return false; } void ComplexHttpTask::check_response() { HttpResponse *resp = this->get_resp(); resp->end_parsing(); if (this->state == WFT_STATE_SYS_ERROR && this->error == ECONNRESET) { /* Servers can end 
the message by closing the connection. */ if (resp->is_header_complete() && !resp->is_chunked() && !resp->has_content_length_header()) { this->state = WFT_STATE_SUCCESS; this->error = 0; } } } bool ComplexHttpTask::finish_once() { if (this->state != WFT_STATE_SUCCESS) this->check_response(); if (this->state == WFT_STATE_SUCCESS) { ParsedURI new_uri; if (this->need_redirect(uri_, new_uri)) { if (uri_.userinfo && strcasecmp(uri_.host, new_uri.host) == 0) { if (!new_uri.userinfo) { new_uri.userinfo = uri_.userinfo; uri_.userinfo = NULL; } } else if (uri_.userinfo) { HttpRequest *client_req = this->get_req(); HttpHeaderCursor cursor(client_req); struct HttpMessageHeader header = { .name = "Authorization", .name_len = strlen("Authorization") }; cursor.find_and_erase(&header); } this->set_redirect(new_uri); } else if (this->state != WFT_STATE_SUCCESS) this->disable_retry(); } return true; } /*******Proxy Client*******/ static SSL *__create_ssl(SSL_CTX *ssl_ctx) { BIO *wbio; BIO *rbio; SSL *ssl; rbio = BIO_new(BIO_s_mem()); if (rbio) { wbio = BIO_new(BIO_s_mem()); if (wbio) { ssl = SSL_new(ssl_ctx); if (ssl) { SSL_set_bio(ssl, rbio, wbio); return ssl; } BIO_free(wbio); } BIO_free(rbio); } return NULL; } class ComplexHttpProxyTask : public ComplexHttpTask { public: ComplexHttpProxyTask(int redirect_max, int retry_max, http_callback_t&& callback): ComplexHttpTask(redirect_max, retry_max, std::move(callback)), is_user_request_(true) { } void set_user_uri(ParsedURI&& uri) { user_uri_ = std::move(uri); } void set_user_uri(const ParsedURI& uri) { user_uri_ = uri; } virtual const ParsedURI *get_current_uri() const { return &user_uri_; } protected: virtual CommMessageOut *message_out(); virtual CommMessageIn *message_in(); virtual int keep_alive_timeout(); virtual int first_timeout(); virtual bool init_success(); virtual bool finish_once(); protected: virtual WFConnection *get_connection() const { WFConnection *conn = this->ComplexHttpTask::get_connection(); if (conn && is_ssl_) 
return (SSLConnection *)conn->get_context(); return conn; } private: struct SSLConnection : public WFConnection { SSL *ssl; SSLHandshaker handshaker; SSLWrapper wrapper; SSLConnection(SSL *ssl) : handshaker(ssl), wrapper(&wrapper, ssl) { this->ssl = ssl; } }; SSLHandshaker *get_ssl_handshaker() const { return &((SSLConnection *)this->get_connection())->handshaker; } SSLWrapper *get_ssl_wrapper(ProtocolMessage *msg) const { SSLConnection *conn = (SSLConnection *)this->get_connection(); conn->wrapper = SSLWrapper(msg, conn->ssl); return &conn->wrapper; } int init_ssl_connection(); std::string proxy_auth_; ParsedURI user_uri_; bool is_ssl_; bool is_user_request_; short state_; int error_; }; int ComplexHttpProxyTask::init_ssl_connection() { static SSL_CTX *ssl_ctx = WFGlobal::get_ssl_client_ctx(); SSL *ssl = __create_ssl(ssl_ctx_ ? ssl_ctx_ : ssl_ctx); WFConnection *conn; if (!ssl) return -1; SSL_set_tlsext_host_name(ssl, user_uri_.host); SSL_set_connect_state(ssl); conn = this->ComplexHttpTask::get_connection(); SSLConnection *ssl_conn = new SSLConnection(ssl); auto&& deleter = [] (void *ctx) { SSLConnection *ssl_conn = (SSLConnection *)ctx; SSL_free(ssl_conn->ssl); delete ssl_conn; }; conn->set_context(ssl_conn, std::move(deleter)); return 0; } CommMessageOut *ComplexHttpProxyTask::message_out() { long long seqid = this->get_seq(); if (seqid == 0) // CONNECT { HttpRequest *conn_req = new HttpRequest; std::string request_uri(user_uri_.host); request_uri += ":"; if (user_uri_.port) request_uri += user_uri_.port; else request_uri += is_ssl_ ? 
"443" : "80"; conn_req->set_method("CONNECT"); conn_req->set_request_uri(request_uri); conn_req->set_http_version("HTTP/1.1"); conn_req->add_header_pair("Host", request_uri.c_str()); if (!proxy_auth_.empty()) conn_req->add_header_pair("Proxy-Authorization", proxy_auth_); is_user_request_ = false; return conn_req; } else if (seqid == 1 && is_ssl_) // HANDSHAKE { is_user_request_ = false; return get_ssl_handshaker(); } auto *msg = (ProtocolMessage *)this->ComplexHttpTask::message_out(); return is_ssl_ ? get_ssl_wrapper(msg) : msg; } CommMessageIn *ComplexHttpProxyTask::message_in() { long long seqid = this->get_seq(); if (seqid == 0) { HttpResponse *conn_resp = new HttpResponse; conn_resp->parse_zero_body(); return conn_resp; } else if (seqid == 1 && is_ssl_) return get_ssl_handshaker(); auto *msg = (ProtocolMessage *)this->ComplexHttpTask::message_in(); return is_ssl_ ? get_ssl_wrapper(msg) : msg; } int ComplexHttpProxyTask::keep_alive_timeout() { long long seqid = this->get_seq(); state_ = WFT_STATE_SUCCESS; error_ = 0; if (seqid == 0) { HttpResponse *resp = this->get_resp(); const char *code_str; int status_code; *resp = std::move(*(HttpResponse *)this->get_message_in()); code_str = resp->get_status_code(); status_code = code_str ? atoi(code_str) : 0; switch (status_code) { case 200: break; case 407: this->disable_retry(); default: state_ = WFT_STATE_TASK_ERROR; error_ = WFT_ERR_HTTP_PROXY_CONNECT_FAILED; return 0; } this->clear_resp(); if (is_ssl_ && init_ssl_connection() < 0) { state_ = WFT_STATE_SYS_ERROR; error_ = errno; return 0; } return HTTP_KEEPALIVE_DEFAULT; } else if (seqid == 1 && is_ssl_) return HTTP_KEEPALIVE_DEFAULT; return this->ComplexHttpTask::keep_alive_timeout(); } int ComplexHttpProxyTask::first_timeout() { return is_user_request_ ? 
this->watch_timeo : 0; } bool ComplexHttpProxyTask::init_success() { if (!uri_.scheme || strcasecmp(uri_.scheme, "http") != 0) { this->state = WFT_STATE_TASK_ERROR; this->error = WFT_ERR_URI_SCHEME_INVALID; return false; } if (user_uri_.state == URI_STATE_ERROR) { this->state = WFT_STATE_SYS_ERROR; this->error = uri_.error; return false; } else if (user_uri_.state != URI_STATE_SUCCESS) { this->state = WFT_STATE_TASK_ERROR; this->error = WFT_ERR_URI_PARSE_FAILED; return false; } if (user_uri_.scheme && strcasecmp(user_uri_.scheme, "http") == 0) is_ssl_ = false; else if (user_uri_.scheme && strcasecmp(user_uri_.scheme, "https") == 0) is_ssl_ = true; else { this->state = WFT_STATE_TASK_ERROR; this->error = WFT_ERR_URI_SCHEME_INVALID; return false; } int user_port; if (user_uri_.port) { user_port = atoi(user_uri_.port); if (user_port <= 0 || user_port > 65535) { this->state = WFT_STATE_TASK_ERROR; this->error = WFT_ERR_URI_PORT_INVALID; return false; } } else user_port = is_ssl_ ? 443 : 80; std::string info("http-proxy|remote:"); info += is_ssl_ ? "https://" : "http://"; info += user_uri_.host; info += ":"; if (user_uri_.port) info += user_uri_.port; else info += is_ssl_ ? 
"443" : "80"; if (uri_.userinfo && uri_.userinfo[0]) { std::string userinfo(uri_.userinfo); StringUtil::url_decode(userinfo); proxy_auth_.clear(); if (__encode_auth(userinfo.c_str(), proxy_auth_) < 0) { this->state = WFT_STATE_SYS_ERROR; this->error = errno; return false; } info += "|auth:"; info += proxy_auth_; } this->WFComplexClientTask::set_info(info); std::string request_uri; std::string header_host; if (user_uri_.path && user_uri_.path[0]) request_uri = user_uri_.path; else request_uri = "/"; if (user_uri_.query && user_uri_.query[0]) { request_uri += "?"; request_uri += user_uri_.query; } if (user_uri_.host && user_uri_.host[0]) header_host = user_uri_.host; if ((is_ssl_ && user_port != 443) || (!is_ssl_ && user_port != 80)) { header_host += ":"; header_host += uri_.port; } HttpRequest *client_req = this->get_req(); client_req->set_request_uri(request_uri.c_str()); client_req->set_header_pair("Host", header_host.c_str()); this->WFComplexClientTask::set_transport_type(TT_TCP); if (user_uri_.userinfo && user_uri_.userinfo[0]) { std::string userinfo(user_uri_.userinfo); std::string http_auth; StringUtil::url_decode(userinfo); if (__encode_auth(userinfo.c_str(), http_auth) < 0) { this->state = WFT_STATE_SYS_ERROR; this->error = errno; return false; } client_req->set_header_pair("Authorization", http_auth.c_str()); } return true; } bool ComplexHttpProxyTask::finish_once() { if (!is_user_request_) { if (this->state == WFT_STATE_SUCCESS && state_ != WFT_STATE_SUCCESS) { this->state = state_; this->error = error_; } if (this->get_seq() == 0) { delete this->get_message_in(); delete this->get_message_out(); } is_user_request_ = true; return false; } if (this->state != WFT_STATE_SUCCESS) this->check_response(); if (this->state == WFT_STATE_SUCCESS) { ParsedURI new_uri; if (this->need_redirect(user_uri_, new_uri)) { if (user_uri_.userinfo && strcasecmp(user_uri_.host, new_uri.host) == 0) { if (!new_uri.userinfo) { new_uri.userinfo = user_uri_.userinfo; 
user_uri_.userinfo = NULL; } } else if (user_uri_.userinfo) { HttpRequest *client_req = this->get_req(); HttpHeaderCursor cursor(client_req); struct HttpMessageHeader header = { .name = "Authorization", .name_len = strlen("Authorization") }; cursor.find_and_erase(&header); } user_uri_ = std::move(new_uri); this->set_redirect(uri_); } else if (this->state != WFT_STATE_SUCCESS) this->disable_retry(); } return true; } /**********Client Factory**********/ WFHttpTask *WFTaskFactory::create_http_task(const std::string& url, int redirect_max, int retry_max, http_callback_t callback) { auto *task = new ComplexHttpTask(redirect_max, retry_max, std::move(callback)); ParsedURI uri; URIParser::parse(url, uri); task->init(std::move(uri)); task->set_keep_alive(HTTP_KEEPALIVE_DEFAULT); return task; } WFHttpTask *WFTaskFactory::create_http_task(const ParsedURI& uri, int redirect_max, int retry_max, http_callback_t callback) { auto *task = new ComplexHttpTask(redirect_max, retry_max, std::move(callback)); task->init(uri); task->set_keep_alive(HTTP_KEEPALIVE_DEFAULT); return task; } WFHttpTask *WFTaskFactory::create_http_task(const std::string& url, const std::string& proxy_url, int redirect_max, int retry_max, http_callback_t callback) { auto *task = new ComplexHttpProxyTask(redirect_max, retry_max, std::move(callback)); ParsedURI uri, user_uri; URIParser::parse(url, user_uri); URIParser::parse(proxy_url, uri); task->set_user_uri(std::move(user_uri)); task->set_keep_alive(HTTP_KEEPALIVE_DEFAULT); task->init(std::move(uri)); return task; } WFHttpTask *WFTaskFactory::create_http_task(const ParsedURI& uri, const ParsedURI& proxy_uri, int redirect_max, int retry_max, http_callback_t callback) { auto *task = new ComplexHttpProxyTask(redirect_max, retry_max, std::move(callback)); task->set_user_uri(uri); task->set_keep_alive(HTTP_KEEPALIVE_DEFAULT); task->init(proxy_uri); return task; } /**********Server**********/ class WFHttpServerTask : public WFServerTask { private: using TASK = 
WFNetworkTask; public: WFHttpServerTask(CommService *service, std::function& proc) : WFServerTask(service, WFGlobal::get_scheduler(), proc), req_is_alive_(false), req_has_keep_alive_header_(false) {} protected: virtual void handle(int state, int error); virtual CommMessageOut *message_out(); protected: bool req_is_alive_; bool req_has_keep_alive_header_; std::string req_keep_alive_; }; void WFHttpServerTask::handle(int state, int error) { if (state == WFT_STATE_TOREPLY) { req_is_alive_ = this->req.is_keep_alive(); if (req_is_alive_ && this->req.has_keep_alive_header()) { HttpHeaderCursor cursor(&this->req); struct HttpMessageHeader header = { .name = "Keep-Alive", .name_len = strlen("Keep-Alive"), }; req_has_keep_alive_header_ = cursor.find(&header); if (req_has_keep_alive_header_) { req_keep_alive_.assign((const char *)header.value, header.value_len); } } } this->WFServerTask::handle(state, error); } CommMessageOut *WFHttpServerTask::message_out() { HttpResponse *resp = this->get_resp(); struct HttpMessageHeader header; if (!resp->get_http_version()) resp->set_http_version("HTTP/1.1"); const char *status_code_str = resp->get_status_code(); if (!status_code_str || !resp->get_reason_phrase()) { int status_code; if (status_code_str) status_code = atoi(status_code_str); else status_code = HttpStatusOK; HttpUtil::set_response_status(resp, status_code); } if (!resp->is_chunked() && !resp->has_content_length_header()) { char buf[32]; header.name = "Content-Length"; header.name_len = strlen("Content-Length"); header.value = buf; header.value_len = sprintf(buf, "%zu", resp->get_output_body_size()); resp->add_header(&header); } bool is_alive; if (resp->has_connection_header()) is_alive = resp->is_keep_alive(); else is_alive = req_is_alive_; if (!is_alive) this->keep_alive_timeo = 0; else { //req---Connection: Keep-Alive //req---Keep-Alive: timeout=5,max=100 if (req_has_keep_alive_header_) { int flag = 0; std::vector params = StringUtil::split(req_keep_alive_, ','); for 
(const auto& kv : params) { std::vector arr = StringUtil::split(kv, '='); if (arr.size() < 2) arr.emplace_back("0"); std::string key = StringUtil::strip(arr[0]); std::string val = StringUtil::strip(arr[1]); if (!(flag & 1) && strcasecmp(key.c_str(), "timeout") == 0) { flag |= 1; // keep_alive_timeo = 5000ms when Keep-Alive: timeout=5 this->keep_alive_timeo = 1000 * atoi(val.c_str()); if (flag == 3) break; } else if (!(flag & 2) && strcasecmp(key.c_str(), "max") == 0) { flag |= 2; if (this->get_seq() >= atoi(val.c_str())) { this->keep_alive_timeo = 0; break; } if (flag == 3) break; } } } if ((unsigned int)this->keep_alive_timeo > HTTP_KEEPALIVE_MAX) this->keep_alive_timeo = HTTP_KEEPALIVE_MAX; //if (this->keep_alive_timeo < 0 || this->keep_alive_timeo > HTTP_KEEPALIVE_MAX) } if (!resp->has_connection_header()) { header.name = "Connection"; header.name_len = 10; if (this->keep_alive_timeo == 0) { header.value = "close"; header.value_len = 5; } else { header.value = "Keep-Alive"; header.value_len = 10; } resp->add_header(&header); } return this->WFServerTask::message_out(); } /**********Server Factory**********/ WFHttpTask *WFServerTaskFactory::create_http_task(CommService *service, std::function& process) { return new WFHttpServerTask(service, process); } workflow-0.11.8/src/factory/KafkaTaskImpl.cc000066400000000000000000000430721476003635400207200ustar00rootroot00000000000000/* Copyright (c) 2020 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Authors: Wang Zhulei (wangzhulei@sogou-inc.com) Xie Han (xiehan@sogou-inc.com) */ #include #include #include #include #include #include #include #include "StringUtil.h" #include "KafkaTaskImpl.inl" using namespace protocol; #define KAFKA_KEEPALIVE_DEFAULT (60 * 1000) #define KAFKA_ROUNDTRIP_TIMEOUT (5 * 1000) static KafkaCgroup __create_cgroup(const KafkaCgroup *c) { KafkaCgroup g; const char *member_id = c->get_member_id(); if (member_id) g.set_member_id(member_id); g.set_group(c->get_group()); return g; } /**********Client**********/ class __ComplexKafkaTask : public WFComplexClientTask { public: __ComplexKafkaTask(int retry_max, __kafka_callback_t&& callback) : WFComplexClientTask(retry_max, std::move(callback)) { is_user_request_ = true; is_redirect_ = false; ctx_ = 0; } protected: virtual CommMessageOut *message_out(); virtual CommMessageIn *message_in(); virtual bool init_success(); virtual bool finish_once(); private: struct KafkaConnectionInfo { kafka_api_t api; kafka_sasl_t sasl; std::string mechanisms; KafkaConnectionInfo() { kafka_api_init(&this->api); kafka_sasl_init(&this->sasl); } ~KafkaConnectionInfo() { kafka_api_deinit(&this->api); kafka_sasl_deinit(&this->sasl); } bool init(const char *mechanisms) { this->mechanisms = mechanisms; if (strncasecmp(mechanisms, "SCRAM", 5) == 0) { if (strcasecmp(mechanisms, "SCRAM-SHA-1") == 0) { this->sasl.scram.evp = EVP_sha1(); this->sasl.scram.scram_h = SHA1; this->sasl.scram.scram_h_size = SHA_DIGEST_LENGTH; } else if (strcasecmp(mechanisms, "SCRAM-SHA-256") == 0) { this->sasl.scram.evp = EVP_sha256(); this->sasl.scram.scram_h = SHA256; this->sasl.scram.scram_h_size = SHA256_DIGEST_LENGTH; } else if (strcasecmp(mechanisms, "SCRAM-SHA-512") == 0) { this->sasl.scram.evp = EVP_sha512(); this->sasl.scram.scram_h = SHA512; this->sasl.scram.scram_h_size = SHA512_DIGEST_LENGTH; } else return false; } return true; } }; virtual int keep_alive_timeout(); virtual int first_timeout(); bool has_next(); bool process_produce(); 
bool process_fetch(); bool process_metadata(); bool process_list_offsets(); bool process_find_coordinator(); bool process_join_group(); bool process_sync_group(); bool process_sasl_authenticate(); bool process_sasl_handshake(); bool is_user_request_; bool is_redirect_; std::string user_info_; }; CommMessageOut *__ComplexKafkaTask::message_out() { long long seqid = this->get_seq(); if (seqid == 0) { KafkaConnectionInfo *conn_info = new KafkaConnectionInfo; this->get_req()->set_api(&conn_info->api); this->get_connection()->set_context(conn_info, [](void *ctx) { delete (KafkaConnectionInfo *)ctx; }); if (!this->get_req()->get_config()->get_broker_version()) { KafkaRequest *req = new KafkaRequest; req->duplicate(*this->get_req()); req->set_api_type(Kafka_ApiVersions); is_user_request_ = false; return req; } else { kafka_api_version_t *api; size_t api_cnt; const char *v = this->get_req()->get_config()->get_broker_version(); int ret = kafka_api_version_is_queryable(v, &api, &api_cnt); kafka_api_version_t *p = NULL; if (ret == 0) { p = (kafka_api_version_t *)malloc(api_cnt * sizeof(*p)); if (p) { memcpy(p, api, api_cnt * sizeof(kafka_api_version_t)); conn_info->api.api = p; conn_info->api.elements = api_cnt; conn_info->api.features = kafka_get_features(p, api_cnt); } } if (!p) return NULL; seqid++; } } if (seqid == 1) { const char *sasl_mech = this->get_req()->get_config()->get_sasl_mech(); KafkaConnectionInfo *conn_info = (KafkaConnectionInfo *)this->get_connection()->get_context(); if (sasl_mech && conn_info->sasl.status == 0) { if (!conn_info->init(sasl_mech)) return NULL; this->get_req()->set_api(&conn_info->api); this->get_req()->set_sasl(&conn_info->sasl); KafkaRequest *req = new KafkaRequest; req->duplicate(*this->get_req()); if (conn_info->api.features & KAFKA_FEATURE_SASL_HANDSHAKE) req->set_api_type(Kafka_SaslHandshake); else req->set_api_type(Kafka_SaslAuthenticate); req->set_correlation_id(1); is_user_request_ = false; return req; } } KafkaConnectionInfo 
*conn_info = (KafkaConnectionInfo *)this->get_connection()->get_context(); KafkaRequest *req = this->get_req(); req->set_api(&conn_info->api); if (req->get_api_type() == Kafka_Fetch || req->get_api_type() == Kafka_ListOffsets) { KafkaTopparList *req_toppar_lst = req->get_toppar_list(); KafkaToppar *toppar; KafkaTopparList toppar_list; bool flag = false; long long cfg_ts = req->get_config()->get_offset_timestamp(); long long tp_ts; req_toppar_lst->rewind(); while ((toppar = req_toppar_lst->get_next()) != NULL) { tp_ts = toppar->get_offset_timestamp(); if (tp_ts == KAFKA_TIMESTAMP_UNINIT) tp_ts = cfg_ts; if (toppar->get_offset() == KAFKA_OFFSET_UNINIT) { if (tp_ts == KAFKA_TIMESTAMP_EARLIEST) toppar->set_offset(toppar->get_low_watermark()); else if (tp_ts < 0) { toppar->set_offset(toppar->get_high_watermark()); tp_ts = KAFKA_TIMESTAMP_LATEST; } } else if (toppar->get_offset() == KAFKA_OFFSET_OVERFLOW) { if (tp_ts == KAFKA_TIMESTAMP_EARLIEST) toppar->set_offset(toppar->get_low_watermark()); else { toppar->set_offset(toppar->get_high_watermark()); tp_ts = KAFKA_TIMESTAMP_LATEST; } } if (toppar->get_offset() < 0) { toppar->set_offset_timestamp(tp_ts); toppar_list.add_item(*toppar); flag = true; } } if (flag) { KafkaRequest *new_req = new KafkaRequest; new_req->set_api(&conn_info->api); new_req->set_broker(*req->get_broker()); new_req->set_toppar_list(toppar_list); new_req->set_config(*req->get_config()); new_req->set_api_type(Kafka_ListOffsets); new_req->set_correlation_id(seqid); is_user_request_ = false; return new_req; } } this->get_req()->set_correlation_id(seqid); return this->WFComplexClientTask::message_out(); } CommMessageIn *__ComplexKafkaTask::message_in() { KafkaRequest *req = static_cast(this->get_message_out()); KafkaResponse *resp = this->get_resp(); KafkaCgroup *cgroup; resp->set_api_type(req->get_api_type()); resp->set_api_version(req->get_api_version()); resp->duplicate(*req); switch (req->get_api_type()) { case Kafka_FindCoordinator: case 
Kafka_Heartbeat: cgroup = req->get_cgroup(); if (cgroup->get_group()) resp->set_cgroup(__create_cgroup(cgroup)); break; default: break; } return this->WFComplexClientTask::message_in(); } bool __ComplexKafkaTask::init_success() { enum TransportType type; if (uri_.scheme && strcasecmp(uri_.scheme, "kafka") == 0) type = TT_TCP; else if (uri_.scheme && strcasecmp(uri_.scheme, "kafkas") == 0) type = TT_TCP_SSL; else { this->state = WFT_STATE_TASK_ERROR; this->error = WFT_ERR_URI_SCHEME_INVALID; return false; } std::string username, password, sasl, client; if (uri_.userinfo) { const char *pos = strchr(uri_.userinfo, ':'); if (pos) { username = std::string(uri_.userinfo, pos - uri_.userinfo); StringUtil::url_decode(username); const char *pos1 = strchr(pos + 1, ':'); if (pos1) { password = std::string(pos + 1, pos1 - pos - 1); StringUtil::url_decode(password); const char *pos2 = strchr(pos1 + 1, ':'); if (pos2) { sasl = std::string(pos1 + 1, pos2 - pos1 - 1); client = std::string(pos1 + 1); } } } if (username.empty() || password.empty() || sasl.empty() || client.empty()) { this->state = WFT_STATE_TASK_ERROR; this->error = WFT_ERR_URI_SCHEME_INVALID; return false; } user_info_ = uri_.userinfo; size_t info_len = username.size() + password.size() + sasl.size() + client.size() + 50; char *info = new char[info_len]; snprintf(info, info_len, "%s|user:%s|pass:%s|sasl:%s|client:%s|", "kafka", username.c_str(), password.c_str(), sasl.c_str(), client.c_str()); this->WFComplexClientTask::set_info(info); delete []info; } this->WFComplexClientTask::set_transport_type(type); return true; } int __ComplexKafkaTask::keep_alive_timeout() { if (this->get_resp()->get_broker()->get_error()) return 0; return this->WFComplexClientTask::keep_alive_timeout(); } int __ComplexKafkaTask::first_timeout() { KafkaRequest *client_req = this->get_req(); int ret = 0; switch(client_req->get_api_type()) { case Kafka_Fetch: ret = client_req->get_config()->get_fetch_timeout(); break; case Kafka_JoinGroup: ret 
= client_req->get_config()->get_session_timeout(); break; case Kafka_SyncGroup: ret = client_req->get_config()->get_rebalance_timeout(); break; case Kafka_Produce: ret = client_req->get_config()->get_produce_timeout(); break; default: return 0; } return ret + KAFKA_ROUNDTRIP_TIMEOUT; } bool __ComplexKafkaTask::process_find_coordinator() { KafkaCgroup *cgroup = this->get_resp()->get_cgroup(); ctx_ = cgroup->get_error(); if (ctx_) { this->error = WFT_ERR_KAFKA_CGROUP_FAILED; this->state = WFT_STATE_TASK_ERROR; return false; } else { this->get_req()->set_cgroup(*cgroup); KafkaBroker *coordinator = cgroup->get_coordinator(); std::string url(uri_.scheme); url += "://"; url += user_info_ + "@"; url += coordinator->get_host(); url += ":" + std::to_string(coordinator->get_port()); ParsedURI uri; URIParser::parse(url, uri); set_redirect(std::move(uri)); this->get_req()->set_api_type(Kafka_JoinGroup); is_redirect_ = true; return true; } } bool __ComplexKafkaTask::process_join_group() { KafkaResponse *msg = this->get_resp(); switch(msg->get_cgroup()->get_error()) { case KAFKA_MEMBER_ID_REQUIRED: this->get_req()->set_api_type(Kafka_JoinGroup); break; case KAFKA_UNKNOWN_MEMBER_ID: msg->get_cgroup()->set_member_id(""); this->get_req()->set_api_type(Kafka_JoinGroup); break; case 0: this->get_req()->set_api_type(Kafka_Metadata); break; default: ctx_ = msg->get_cgroup()->get_error(); this->error = WFT_ERR_KAFKA_CGROUP_FAILED; this->state = WFT_STATE_TASK_ERROR; return false; } return true; } bool __ComplexKafkaTask::process_sync_group() { ctx_ = this->get_resp()->get_cgroup()->get_error(); if (ctx_) { this->error = WFT_ERR_KAFKA_CGROUP_FAILED; this->state = WFT_STATE_TASK_ERROR; return false; } else { this->get_req()->set_api_type(Kafka_OffsetFetch); return true; } } bool __ComplexKafkaTask::process_metadata() { KafkaResponse *msg = this->get_resp(); msg->get_meta_list()->rewind(); KafkaMeta *meta; while ((meta = msg->get_meta_list()->get_next()) != NULL) { switch 
(meta->get_error())
		{
		case KAFKA_LEADER_NOT_AVAILABLE:
			// Ask for the metadata again.
			this->get_req()->set_api_type(Kafka_Metadata);
			return true;
		case 0:
			break;
		default:
			ctx_ = meta->get_error();
			this->error = WFT_ERR_KAFKA_META_FAILED;
			this->state = WFT_STATE_TASK_ERROR;
			return false;
		}
	}

	this->get_req()->set_meta_list(*msg->get_meta_list());
	// With a consumer group set, the leader runs the assignor first; the
	// request is then switched to SyncGroup and another round is requested.
	if (msg->get_cgroup()->get_group())
	{
		if (msg->get_cgroup()->is_leader())
		{
			KafkaCgroup *cgroup = msg->get_cgroup();
			if (cgroup->run_assignor(msg->get_meta_list(),
									 cgroup->get_protocol_name()) < 0)
			{
				this->error = WFT_ERR_KAFKA_CGROUP_ASSIGN_FAILED;
				this->state = WFT_STATE_TASK_ERROR;
				return false;
			}
		}

		this->get_req()->set_api_type(Kafka_SyncGroup);
		return true;
	}

	return false;
}

/* Inspect every topic-partition of a Fetch response.
 * Returns true (another round is wanted) when at least one partition reported
 * KAFKA_OFFSET_OUT_OF_RANGE; that partition's offset and watermarks are reset
 * first. Any other partition error fails the task and returns false at once. */
bool __ComplexKafkaTask::process_fetch()
{
	bool ret = false;
	KafkaToppar *toppar;
	this->get_resp()->get_toppar_list()->rewind();
	while ((toppar = this->get_resp()->get_toppar_list()->get_next()) != NULL)
	{
		int toppar_error = toppar->get_error();
		if (toppar_error == KAFKA_OFFSET_OUT_OF_RANGE)
		{
			toppar->set_offset(KAFKA_OFFSET_OVERFLOW);
			toppar->set_low_watermark(KAFKA_OFFSET_UNINIT);
			toppar->set_high_watermark(KAFKA_OFFSET_UNINIT);
			ret = true;
		}
		else if (toppar_error)
		{
			ctx_ = toppar_error;
			this->error = WFT_ERR_KAFKA_FETCH_FAILED;
			this->state = WFT_STATE_TASK_ERROR;
			return false;
		}
	}
	return ret;
}

/* Scan a ListOffsets response; a partition error marks the task as failed.
 * The loop does not stop at the first error, so the last failing partition's
 * code is what remains in this->error. Always returns false (no next step).
 * NOTE(review): unlike the sibling handlers, the raw partition error is stored
 * in this->error and ctx_ is left untouched -- confirm this is intended. */
bool __ComplexKafkaTask::process_list_offsets()
{
	KafkaToppar *toppar;
	this->get_resp()->get_toppar_list()->rewind();
	while ((toppar = this->get_resp()->get_toppar_list()->get_next()) != NULL)
	{
		if (toppar->get_error())
		{
			this->error = toppar->get_error();
			this->state = WFT_STATE_TASK_ERROR;
		}
	}
	return false;
}

/* Handle a Produce response.
 * Returns true, with the request set back to Kafka_Produce, as soon as a
 * partition reports !record_reach_end() (presumably: records still pending --
 * confirm against KafkaToppar). Otherwise the first partition error fails the
 * task. Returns false when nothing more has to be sent. */
bool __ComplexKafkaTask::process_produce()
{
	KafkaToppar *toppar;
	this->get_resp()->get_toppar_list()->rewind();
	while ((toppar = this->get_resp()->get_toppar_list()->get_next()) != NULL)
	{
		if (!toppar->record_reach_end())
		{
			this->get_req()->set_api_type(Kafka_Produce);
			return true;
		}

		if (toppar->get_error())
		{
			ctx_ = toppar->get_error();
			this->error = WFT_ERR_KAFKA_PRODUCE_FAILED;
			this->state = WFT_STATE_TASK_ERROR;
			return false;
		}
	}
	return false;
}

/* SASL handshake result: a non-zero broker error fails the task; otherwise
 * true is returned so that a further request follows. */
bool __ComplexKafkaTask::process_sasl_handshake()
{
	ctx_ = this->get_resp()->get_broker()->get_error();
	if (ctx_)
	{
		this->error = WFT_ERR_KAFKA_SASL_DISALLOWED;
		this->state = WFT_STATE_TASK_ERROR;
		return false;
	}
	return true;
}

/* SASL authenticate result: a non-zero broker error fails the task.
 * Always returns false -- no follow-up request is generated here. */
bool __ComplexKafkaTask::process_sasl_authenticate()
{
	ctx_ = this->get_resp()->get_broker()->get_error();
	if (ctx_)
	{
		this->error = WFT_ERR_KAFKA_SASL_DISALLOWED;
		this->state = WFT_STATE_TASK_ERROR;
	}
	return false;
}

/* Dispatch on the API type of the response and report whether another
 * request/response round is needed (true) or the exchange is over (false).
 * The handlers may set this->state / this->error / ctx_ as a side effect. */
bool __ComplexKafkaTask::has_next()
{
	switch (this->get_resp()->get_api_type())
	{
	case Kafka_Produce:
		return this->process_produce();
	case Kafka_Fetch:
		return this->process_fetch();
	case Kafka_Metadata:
		return this->process_metadata();
	case Kafka_FindCoordinator:
		return this->process_find_coordinator();
	case Kafka_JoinGroup:
		return this->process_join_group();
	case Kafka_SyncGroup:
		return this->process_sync_group();
	case Kafka_SaslHandshake:
		return this->process_sasl_handshake();
	case Kafka_SaslAuthenticate:
		return this->process_sasl_authenticate();
	case Kafka_ListOffsets:
		return this->process_list_offsets();
	// These group APIs only need the consumer-group error code checked.
	case Kafka_OffsetCommit:
	case Kafka_OffsetFetch:
	case Kafka_LeaveGroup:
	case Kafka_DescribeGroups:
	case Kafka_Heartbeat:
		ctx_ = this->get_resp()->get_cgroup()->get_error();
		if (ctx_)
		{
			this->error = WFT_ERR_KAFKA_CGROUP_FAILED;
			this->state = WFT_STATE_TASK_ERROR;
		}

		break;
	case Kafka_ApiVersions:
		break;
	default:
		this->state = WFT_STATE_TASK_ERROR;
		this->error = WFT_ERR_KAFKA_API_UNKNOWN;
		break;
	}
	return false;
}

/* Called after each request/response round.
 * Returns true when the whole task is finished, false when another round
 * must be issued. */
bool __ComplexKafkaTask::finish_once()
{
	bool finish = true;

	// Only a successful round is inspected for a follow-up step.
	if (this->state == WFT_STATE_SUCCESS)
		finish = !has_next();

	// An internal (non-user) request was allocated by this task: release it.
	if (!is_user_request_)
	{
		delete this->get_message_out();
		this->get_resp()->clear_buf();
	}

	if (is_redirect_ && this->state == WFT_STATE_UNDEFINED)
	{
		this->get_req()->clear_buf();
		is_redirect_ = false;
	}
	else if (this->state == WFT_STATE_SUCCESS)
	{
		// An internal round just succeeded: go on with the user request.
		if (!is_user_request_)
		{
			is_user_request_ = true;
			return false;
		}

		if (!finish)
		{
			this->get_req()->clear_buf();
			this->get_resp()->clear_buf();
			return false;
		}
	}
	else
	{
		// On failure, copy the request's API type/version into the response.
		this->get_resp()->set_api_type(this->get_req()->get_api_type());
		this->get_resp()->set_api_version(this->get_req()->get_api_version());
	}

	is_user_request_ = true;
	return true;
}

/**********Factory**********/
// kafka://user:password:sasl@host:port/api=type&topic=name
/* Create an internal Kafka task from a URL string. The parse result is not
 * checked here; the ParsedURI is handed to task->init() as is. */
__WFKafkaTask *__WFKafkaTaskFactory::create_kafka_task(const std::string& url,
													   SSL_CTX *ssl_ctx,
													   int retry_max,
													   __kafka_callback_t callback)
{
	auto *task = new __ComplexKafkaTask(retry_max, std::move(callback));

	task->set_ssl_ctx(ssl_ctx);

	ParsedURI uri;

	URIParser::parse(url, uri);
	task->init(std::move(uri));
	task->set_keep_alive(KAFKA_KEEPALIVE_DEFAULT);
	return task;
}

/* Create an internal Kafka task from an already parsed URI. */
__WFKafkaTask *__WFKafkaTaskFactory::create_kafka_task(const ParsedURI& uri,
													   SSL_CTX *ssl_ctx,
													   int retry_max,
													   __kafka_callback_t callback)
{
	auto *task = new __ComplexKafkaTask(retry_max, std::move(callback));

	task->set_ssl_ctx(ssl_ctx);

	task->init(uri);
	task->set_keep_alive(KAFKA_KEEPALIVE_DEFAULT);
	return task;
}

/* Create an internal Kafka task from discrete connection parameters.
 * A ParsedURI is assembled by hand: scheme "kafkas" for TT_TCP_SSL and "kafka"
 * otherwise, `info` as userinfo when non-empty, plus host and port. If any
 * strdup() fails the URI is marked URI_STATE_ERROR with the current errno. */
__WFKafkaTask *__WFKafkaTaskFactory::create_kafka_task(enum TransportType type,
													   const char *host,
													   unsigned short port,
													   SSL_CTX *ssl_ctx,
													   const std::string& info,
													   int retry_max,
													   __kafka_callback_t callback)
{
	auto *task = new __ComplexKafkaTask(retry_max, std::move(callback));

	task->set_ssl_ctx(ssl_ctx);

	ParsedURI uri;
	char buf[32];

	if (type == TT_TCP_SSL)
		uri.scheme = strdup("kafkas");
	else
		uri.scheme = strdup("kafka");

	if (!info.empty())
		uri.userinfo = strdup(info.c_str());

	uri.host = strdup(host);
	sprintf(buf, "%u", port);
	uri.port = strdup(buf);

	if (!uri.scheme || !uri.host || !uri.port ||
		(!info.empty() && !uri.userinfo))
	{
		uri.state = URI_STATE_ERROR;
		uri.error = errno;
	}
	else
		uri.state = URI_STATE_SUCCESS;

	task->init(std::move(uri));
	task->set_keep_alive(KAFKA_KEEPALIVE_DEFAULT);
	return task;
}
workflow-0.11.8/src/factory/KafkaTaskImpl.inl000066400000000000000000000033021476003635400211050ustar00rootroot00000000000000/*
  Copyright (c) 2020 Sogou, Inc.

  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

  Unless required by applicable law or agreed to in writing, software
  distributed under the License is distributed on an "AS IS" BASIS,
  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  See the License for the specific language governing permissions and
  limitations under the License.

  Authors: Wang Zhulei (wangzhulei@sogou-inc.com)
*/

// NOTE(review): the header name after the first #include and the template
// arguments of WFNetworkTask / std::function below appear to have been
// stripped from this copy of the file -- restore them from the original.
#include
#include "WFTaskFactory.h"
#include "KafkaMessage.h"

// Kafka internal task. For __ComplexKafkaTask usage only
using __WFKafkaTask = WFNetworkTask;
using __kafka_callback_t = std::function;

class __WFKafkaTaskFactory
{
public:
	/* A __WFKafkaTask is created by __ComplexKafkaTask. This is an internal
	 * interface for creating internal tasks; it should not be called
	 * directly by ordinary user code.
	 * All overloads take the SSL context, the retry limit and the callback;
	 * they differ only in how the broker address is given (parsed URI, URL
	 * string, or transport type + host + port + userinfo string).
	 */
	static __WFKafkaTask *create_kafka_task(const ParsedURI& uri,
											SSL_CTX *ssl_ctx,
											int retry_max,
											__kafka_callback_t callback);

	static __WFKafkaTask *create_kafka_task(const std::string& url,
											SSL_CTX *ssl_ctx,
											int retry_max,
											__kafka_callback_t callback);

	static __WFKafkaTask *create_kafka_task(enum TransportType type,
											const char *host,
											unsigned short port,
											SSL_CTX *ssl_ctx,
											const std::string& info,
											int retry_max,
											__kafka_callback_t callback);
};
workflow-0.11.8/src/factory/MySQLTaskImpl.cc000066400000000000000000000470111476003635400206450ustar00rootroot00000000000000/*
  Copyright (c) 2019 Sogou, Inc.

  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

  Unless required by applicable law or agreed to in writing, software
  distributed under the License is distributed on an "AS IS" BASIS,
  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  See the License for the specific language governing permissions and
  limitations under the License.

  Authors: Xie Han (xiehan@sogou-inc.com)
           Wu Jiaxu (wujiaxu@sogou-inc.com)
           Li Yingxin (liyingxin@sogou-inc.com)
*/

// NOTE(review): the seven bare #include directives below lost their header
// names (and several templates lost their argument lists) in this copy of
// the file -- restore them from the original source.
#include
#include
#include
#include
#include
#include
#include
#include "WFTaskError.h"
#include "WFTaskFactory.h"
#include "StringUtil.h"
#include "WFGlobal.h"
#include "mysql_types.h"

using namespace protocol;

// Keep-alive values handed to set_keep_alive() / returned between the
// connection set-up steps (presumably milliseconds -- confirm).
#define MYSQL_KEEPALIVE_DEFAULT		(60 * 1000)
#define MYSQL_KEEPALIVE_TRANSACTION	(3600 * 1000)

/**********Client**********/

/* MySQL client task. Before the user's query is sent on a new connection the
 * task drives a handshake / optional SSL / authentication / optional
 * "SET NAMES" exchange; the step a connection is at is kept in a
 * MyConnection object stored as the connection's context. */
class ComplexMySQLTask : public WFComplexClientTask
{
protected:
	virtual bool check_request();
	virtual CommMessageOut *message_out();
	virtual CommMessageIn *message_in();
	virtual int keep_alive_timeout();
	virtual int first_timeout();
	virtual bool init_success();
	virtual bool finish_once();

protected:
	// Returns the connection's context (a MyConnection, once
	// check_handshake() has installed one) instead of the raw connection.
	virtual WFConnection *get_connection() const
	{
		WFConnection *conn = this->WFComplexClientTask::get_connection();

		if (conn)
		{
			void *ctx = conn->get_context();
			if (ctx)
				conn = (WFConnection *)ctx;
		}

		return conn;
	}

private:
	// Which request has to be sent next on a connection.
	enum ConnState
	{
		ST_SSL_REQUEST,
		ST_AUTH_REQUEST,
		ST_AUTH_SWITCH_REQUEST,
		ST_CLEAR_PASSWORD_REQUEST,
		ST_SHA256_PUBLIC_KEY_REQUEST,
		ST_CSHA2_PUBLIC_KEY_REQUEST,
		ST_RSA_AUTH_REQUEST,
		ST_CHARSET_REQUEST,
		ST_FIRST_USER_REQUEST,
		ST_USER_REQUEST
	};

	// Per-connection state kept across the set-up rounds.
	struct MyConnection : public WFConnection
	{
		std::string str;	// shared by auth, auth_switch and rsa_auth requests
		unsigned char seed[20];
		enum ConnState state;
		unsigned char mysql_seqid;
		SSL *ssl;
		SSLWrapper wrapper;	// reused for user requests/responses over SSL
		MyConnection(SSL *ssl) : wrapper(&wrapper, ssl)
		{
			this->ssl = ssl;
		}
	};

	int check_handshake(MySQLHandshakeResponse *resp);
	int auth_switch(MySQLAuthResponse *resp, MyConnection *conn);

	// SSL wrapper that owns the wrapped message and deletes it on destruction;
	// used for the internally allocated (non-user) messages.
	struct MySSLWrapper : public SSLWrapper
	{
		MySSLWrapper(ProtocolMessage *msg, SSL *ssl) :
			SSLWrapper(msg, ssl)
		{
		}
		ProtocolMessage *get_msg() const { return this->message; }
		virtual ~MySSLWrapper() { delete this->message; }
	};

private:
	std::string username_;
	std::string password_;
	std::string db_;
	std::string res_charset_;
	short character_set_;
	short state_;		// result of the last internal step (WFT_STATE_*)
	int error_;			// error code going with state_
	bool is_ssl_;
	bool is_user_request_;

public:
	// character_set_ starts at 33 (utf8 in the charset table of this file).
	ComplexMySQLTask(int retry_max, mysql_callback_t&& callback):
		WFComplexClientTask(retry_max, std::move(callback)),
		character_set_(33),
		is_user_request_(true)
	{}
};

/* Accept only a COM_QUERY whose text does not start with "USE ", "SET NAMES ",
 * "SET CHARSET " or "SET CHARACTER SET " (case-insensitive). Otherwise the
 * task fails with WFT_ERR_MYSQL_COMMAND_DISALLOWED, or with
 * WFT_ERR_MYSQL_QUERY_NOT_SET when no query was set. */
bool ComplexMySQLTask::check_request()
{
	if (this->req.query_is_unset() == false)
	{
		if (this->req.get_command() == MYSQL_COM_QUERY)
		{
			std::string query = this->req.get_query();

			if (strncasecmp(query.c_str(), "USE ", 4) &&
				strncasecmp(query.c_str(), "SET NAMES ", 10) &&
				strncasecmp(query.c_str(), "SET CHARSET ", 12) &&
				strncasecmp(query.c_str(), "SET CHARACTER SET ", 18))
			{
				return true;
			}
		}

		this->error = WFT_ERR_MYSQL_COMMAND_DISALLOWED;
	}
	else
		this->error = WFT_ERR_MYSQL_QUERY_NOT_SET;

	this->state = WFT_STATE_TASK_ERROR;
	return false;
}

/* Create an SSL object attached to a pair of memory BIOs.
 * Returns NULL on any allocation failure; BIOs already created are freed. */
static SSL *__create_ssl(SSL_CTX *ssl_ctx)
{
	BIO *wbio;
	BIO *rbio;
	SSL *ssl;

	rbio = BIO_new(BIO_s_mem());
	if (rbio)
	{
		wbio = BIO_new(BIO_s_mem());
		if (wbio)
		{
			ssl = SSL_new(ssl_ctx);
			if (ssl)
			{
				SSL_set_bio(ssl, rbio, wbio);
				return ssl;
			}

			BIO_free(wbio);
		}

		BIO_free(rbio);
	}

	return NULL;
}

/* Build the next outgoing message according to the connection state.
 * seq 0 sends the handshake request; later rounds send whatever conn->state
 * asks for. Internally created requests leave is_user_request_ false; only
 * the ST_FIRST_USER_REQUEST / ST_USER_REQUEST cases send the user's request.
 * With SSL enabled the message is returned inside an SSL wrapper. */
CommMessageOut *ComplexMySQLTask::message_out()
{
	MySQLAuthSwitchRequest *auth_switch_req;
	MySQLRSAAuthRequest *rsa_auth_req;
	MySQLAuthRequest *auth_req;
	MySQLRequest *req;

	is_user_request_ = false;
	if (this->get_seq() == 0)
		return new MySQLHandshakeRequest;

	auto *conn = (MyConnection *)this->get_connection();

	switch (conn->state)
	{
	case ST_SSL_REQUEST:
		// Sent as is: the SSL request itself is never SSL-wrapped.
		req = new MySQLSSLRequest(character_set_, conn->ssl);
		req->set_seqid(conn->mysql_seqid);
		return req;

	case ST_AUTH_REQUEST:
		req = new MySQLAuthRequest;
		auth_req = (MySQLAuthRequest *)req;
		auth_req->set_auth(username_, password_, db_, character_set_);
		auth_req->set_auth_plugin_name(std::move(conn->str));
		auth_req->set_seed(conn->seed);
		break;

	case ST_CLEAR_PASSWORD_REQUEST:
		conn->str = "mysql_clear_password";
		// fall through: sent as an auth-switch request
	case ST_AUTH_SWITCH_REQUEST:
		req = new MySQLAuthSwitchRequest;
		auth_switch_req = (MySQLAuthSwitchRequest *)req;
		auth_switch_req->set_password(password_);
		auth_switch_req->set_auth_plugin_name(std::move(conn->str));
		auth_switch_req->set_seed(conn->seed);
#if OPENSSL_VERSION_NUMBER < 0x10100000L
		// Presumably forces OpenSSL initialisation on old versions -- confirm.
		WFGlobal::get_ssl_client_ctx();
#endif
		break;

	case ST_SHA256_PUBLIC_KEY_REQUEST:
		req = new MySQLPublicKeyRequest;
		((MySQLPublicKeyRequest *)req)->set_sha256();
		break;

	case ST_CSHA2_PUBLIC_KEY_REQUEST:
		req = new MySQLPublicKeyRequest;
		((MySQLPublicKeyRequest *)req)->set_caching_sha2();
		break;

	case ST_RSA_AUTH_REQUEST:
		req = new MySQLRSAAuthRequest;
		rsa_auth_req = (MySQLRSAAuthRequest *)req;
		rsa_auth_req->set_password(password_);
		rsa_auth_req->set_public_key(std::move(conn->str));
		rsa_auth_req->set_seed(conn->seed);
		break;

	case ST_CHARSET_REQUEST:
		// NOTE(review): res_charset_ comes from the URL query string and is
		// concatenated into SQL unescaped -- confirm URLs are trusted input.
		req = new MySQLRequest;
		req->set_query("SET NAMES " + res_charset_);
		break;

	case ST_FIRST_USER_REQUEST:
		if (this->is_fixed_conn())
		{
			auto *target = (RouteManager::RouteTarget *)this->target;

			/* If it's a transaction task, generate an ECONNRESET error when
			 * the target was reconnected. */
			if (target->state)
			{
				is_user_request_ = true;
				errno = ECONNRESET;
				return NULL;
			}

			target->state = 1;
		}
		// fall through: send the user's request

	case ST_USER_REQUEST:
		is_user_request_ = true;
		req = (MySQLRequest *)this->WFComplexClientTask::message_out();
		break;

	default:
		assert(0);
		return NULL;
	}

	// Internal requests other than "SET NAMES" continue the sequence id.
	if (!is_user_request_ && conn->state != ST_CHARSET_REQUEST)
		req->set_seqid(conn->mysql_seqid);

	if (!is_ssl_)
		return req;

	// User messages reuse the connection's wrapper; internal ones get an
	// owning wrapper so that deleting the wrapper also deletes the message.
	if (is_user_request_)
	{
		conn->wrapper = SSLWrapper(req, conn->ssl);
		return &conn->wrapper;
	}
	else
		return new MySSLWrapper(req, conn->ssl);
}

/* Create the incoming message matching the request that was just sent.
 * Mirrors message_out(): the response type depends on conn->state, and the
 * SSL wrapping rules are the same. */
CommMessageIn *ComplexMySQLTask::message_in()
{
	MySQLResponse *resp;

	if (this->get_seq() == 0)
		return new MySQLHandshakeResponse;

	auto *conn = (MyConnection *)this->get_connection();

	switch (conn->state)
	{
	case ST_SSL_REQUEST:
		return new SSLHandshaker(conn->ssl);

	case ST_AUTH_REQUEST:
	case ST_AUTH_SWITCH_REQUEST:
		resp = new MySQLAuthResponse;
		break;

	case ST_CLEAR_PASSWORD_REQUEST:
	case ST_RSA_AUTH_REQUEST:
		resp = new MySQLResponse;
		break;

	case ST_SHA256_PUBLIC_KEY_REQUEST:
	case ST_CSHA2_PUBLIC_KEY_REQUEST:
		resp = new MySQLPublicKeyResponse;
		break;

	case ST_CHARSET_REQUEST:
		resp = new MySQLResponse;
		break;

	case ST_FIRST_USER_REQUEST:
	case ST_USER_REQUEST:
		resp = (MySQLResponse *)this->WFComplexClientTask::message_in();
		break;

	default:
		assert(0);
		return NULL;
	}

	if (!is_ssl_)
		return resp;

	if (is_user_request_)
	{
		conn->wrapper = SSLWrapper(resp, conn->ssl);
		return &conn->wrapper;
	}
	else
		return new MySSLWrapper(resp, conn->ssl);
}

/* Process the server handshake (seq 0).
 * On success a MyConnection is created, filled with the auth plugin name and
 * seed, and installed as the connection context (with a deleter that frees
 * the SSL object). Returns the keep-alive to use, or 0 after recording a
 * failure in state_ / error_. */
int ComplexMySQLTask::check_handshake(MySQLHandshakeResponse *resp)
{
	SSL *ssl = NULL;

	if (resp->host_disallowed())
	{
		this->resp = std::move(*(MySQLResponse *)resp);
		state_ = WFT_STATE_TASK_ERROR;
		error_ = WFT_ERR_MYSQL_HOST_NOT_ALLOWED;
		return 0;
	}

	if (is_ssl_)
	{
		// 0x800: presumably the CLIENT_SSL capability bit -- confirm against
		// the protocol definition.
		if (resp->get_capability_flags() & 0x800)
		{
			static SSL_CTX *ssl_ctx = WFGlobal::get_ssl_client_ctx();

			// A task-specific context (ssl_ctx_) takes precedence.
			ssl = __create_ssl(ssl_ctx_ ? ssl_ctx_ : ssl_ctx);
			if (!ssl)
			{
				state_ = WFT_STATE_SYS_ERROR;
				error_ = errno;
				return 0;
			}

			SSL_set_connect_state(ssl);
		}
		else
		{
			this->resp = std::move(*(MySQLResponse *)resp);
			state_ = WFT_STATE_TASK_ERROR;
			error_ = WFT_ERR_MYSQL_SSL_NOT_SUPPORTED;
			return 0;
		}
	}

	auto *conn = this->get_connection();
	auto *my_conn = new MyConnection(ssl);

	my_conn->str = resp->get_auth_plugin_name();
	// With a password, "sha256_password" is answered as "caching_sha2_password".
	if (!password_.empty() && my_conn->str == "sha256_password")
		my_conn->str = "caching_sha2_password";

	resp->get_seed(my_conn->seed);
	my_conn->state = is_ssl_ ? ST_SSL_REQUEST : ST_AUTH_REQUEST;
	my_conn->mysql_seqid = resp->get_seqid() + 1;
	conn->set_context(my_conn, [](void *ctx) {
		auto *my_conn = (MyConnection *)ctx;
		if (my_conn->ssl)
			SSL_free(my_conn->ssl);
		delete my_conn;
	});

	return MYSQL_KEEPALIVE_DEFAULT;
}

/* Handle an auth-switch response received in state ST_AUTH_REQUEST.
 * Chooses the next state from the requested plugin name; a switch in any
 * other state, or "mysql_clear_password" without SSL, is rejected with
 * EBADMSG. Returns the keep-alive to use, or 0 on failure. */
int ComplexMySQLTask::auth_switch(MySQLAuthResponse *resp, MyConnection *conn)
{
	std::string name = resp->get_auth_plugin_name();

	if (conn->state != ST_AUTH_REQUEST ||
		(name == "mysql_clear_password" && !is_ssl_))
	{
		state_ = WFT_STATE_SYS_ERROR;
		error_ = EBADMSG;
		return 0;
	}

	if (password_.empty())
	{
		conn->state = ST_CLEAR_PASSWORD_REQUEST;
	}
	else if (name == "sha256_password")
	{
		if (is_ssl_)
			conn->state = ST_CLEAR_PASSWORD_REQUEST;
		else
			conn->state = ST_SHA256_PUBLIC_KEY_REQUEST;
	}
	else
	{
		conn->str = std::move(name);
		conn->state = ST_AUTH_SWITCH_REQUEST;
	}

	resp->get_seed(conn->seed);
	conn->mysql_seqid = resp->get_seqid() + 1;
	return MYSQL_KEEPALIVE_DEFAULT;
}

/* Called once a response has been received: advance the connection state
 * machine and return the keep-alive for the connection. A return of 0 goes
 * with a failure recorded in state_ / error_ (picked up by finish_once()). */
int ComplexMySQLTask::keep_alive_timeout()
{
	auto *msg = (ProtocolMessage *)this->get_message_in();
	MySQLAuthResponse *auth_resp;
	MySQLResponse *resp;

	state_ = WFT_STATE_SUCCESS;
	error_ = 0;
	if (this->get_seq() == 0)
		return check_handshake((MySQLHandshakeResponse *)msg);

	auto *conn = (MyConnection *)this->get_connection();

	// SSL handshake done: authenticate next.
	if (conn->state == ST_SSL_REQUEST)
	{
		conn->state = ST_AUTH_REQUEST;
		conn->mysql_seqid++;
		return MYSQL_KEEPALIVE_DEFAULT;
	}

	// Over SSL the real response sits inside the wrapper.
	if (is_ssl_)
		resp = (MySQLResponse *)((MySSLWrapper *)msg)->get_msg();
	else
		resp = (MySQLResponse *)msg;

	switch (conn->state)
	{
	case ST_AUTH_REQUEST:
	case ST_AUTH_SWITCH_REQUEST:
	case ST_CLEAR_PASSWORD_REQUEST:
	case ST_RSA_AUTH_REQUEST:
		if (resp->is_ok_packet())
		{
			if (!res_charset_.empty())
				conn->state = ST_CHARSET_REQUEST;
			else
				conn->state = ST_FIRST_USER_REQUEST;

			break;
		}

		// Clear-password and RSA steps must end with OK; anything else,
		// like an explicit error packet, means access denied.
		if (resp->is_error_packet() ||
			conn->state == ST_CLEAR_PASSWORD_REQUEST ||
			conn->state == ST_RSA_AUTH_REQUEST)
		{
			this->resp = std::move(*resp);
			state_ = WFT_STATE_TASK_ERROR;
			error_ = WFT_ERR_MYSQL_ACCESS_DENIED;
			return 0;
		}

		auth_resp = (MySQLAuthResponse *)resp;
		if (auth_resp->is_continue())
		{
			if (is_ssl_)
				conn->state = ST_CLEAR_PASSWORD_REQUEST;
			else
				conn->state = ST_CSHA2_PUBLIC_KEY_REQUEST;

			break;
		}

		return auth_switch(auth_resp, conn);

	case ST_SHA256_PUBLIC_KEY_REQUEST:
	case ST_CSHA2_PUBLIC_KEY_REQUEST:
		conn->str = ((MySQLPublicKeyResponse *)resp)->get_public_key();
		conn->state = ST_RSA_AUTH_REQUEST;
		break;

	case ST_CHARSET_REQUEST:
		if (!resp->is_ok_packet())
		{
			this->resp = std::move(*resp);
			state_ = WFT_STATE_TASK_ERROR;
			error_ = WFT_ERR_MYSQL_INVALID_CHARACTER_SET;
			return 0;
		}

		conn->state = ST_FIRST_USER_REQUEST;
		return MYSQL_KEEPALIVE_DEFAULT;

	case ST_FIRST_USER_REQUEST:
		conn->state = ST_USER_REQUEST;
		// fall through
	case ST_USER_REQUEST:
		return this->keep_alive_timeo;

	default:
		assert(0);
		return 0;
	}

	conn->mysql_seqid = resp->get_seqid() + 1;
	return MYSQL_KEEPALIVE_DEFAULT;
}

int ComplexMySQLTask::first_timeout()
{
	return is_user_request_ ?
this->watch_timeo : 0; } /* +--------------------+---------------------+-----+ | CHARACTER_SET_NAME | COLLATION_NAME | ID | +--------------------+---------------------+-----+ | big5 | big5_chinese_ci | 1 | | dec8 | dec8_swedish_ci | 3 | | cp850 | cp850_general_ci | 4 | | hp8 | hp8_english_ci | 6 | | koi8r | koi8r_general_ci | 7 | | latin1 | latin1_swedish_ci | 8 | | latin2 | latin2_general_ci | 9 | | swe7 | swe7_swedish_ci | 10 | | ascii | ascii_general_ci | 11 | | ujis | ujis_japanese_ci | 12 | | sjis | sjis_japanese_ci | 13 | | hebrew | hebrew_general_ci | 16 | | tis620 | tis620_thai_ci | 18 | | euckr | euckr_korean_ci | 19 | | koi8u | koi8u_general_ci | 22 | | gb2312 | gb2312_chinese_ci | 24 | | greek | greek_general_ci | 25 | | cp1250 | cp1250_general_ci | 26 | | gbk | gbk_chinese_ci | 28 | | latin5 | latin5_turkish_ci | 30 | | armscii8 | armscii8_general_ci | 32 | | utf8 | utf8_general_ci | 33 | | ucs2 | ucs2_general_ci | 35 | | cp866 | cp866_general_ci | 36 | | keybcs2 | keybcs2_general_ci | 37 | | macce | macce_general_ci | 38 | | macroman | macroman_general_ci | 39 | | cp852 | cp852_general_ci | 40 | | latin7 | latin7_general_ci | 41 | | cp1251 | cp1251_general_ci | 51 | | utf16 | utf16_general_ci | 54 | | utf16le | utf16le_general_ci | 56 | | cp1256 | cp1256_general_ci | 57 | | cp1257 | cp1257_general_ci | 59 | | utf32 | utf32_general_ci | 60 | | binary | binary | 63 | | geostd8 | geostd8_general_ci | 92 | | cp932 | cp932_japanese_ci | 95 | | eucjpms | eucjpms_japanese_ci | 97 | | gb18030 | gb18030_chinese_ci | 248 | | utf8mb4 | utf8mb4_0900_ai_ci | 255 | +--------------------+---------------------+-----+ */ static int __mysql_get_character_set(const std::string& charset) { static std::unordered_map charset_map = { {"big5", 1}, {"dec8", 3}, {"cp850", 4}, {"hp8", 5}, {"koi8r", 6}, {"latin1", 7}, {"latin2", 8}, {"swe7", 10}, {"ascii", 11}, {"ujis", 12}, {"sjis", 13}, {"hebrew", 16}, {"tis620", 18}, {"euckr", 19}, {"koi8u", 22}, {"gb2312", 24}, {"greek", 25}, 
{"cp1250", 26}, {"gbk", 28}, {"latin5", 30}, {"armscii8",32}, {"utf8", 33}, {"ucs2", 35}, {"cp866", 36}, {"keybcs2", 37}, {"macce", 38}, {"macroman",39}, {"cp852", 40}, {"latin7", 41}, {"cp1251", 51}, {"utf16", 54}, {"utf16le", 56}, {"cp1256", 57}, {"cp1257", 59}, {"utf32", 60}, {"binary", 63}, {"geostd8", 92}, {"cp932", 95}, {"eucjpms", 97}, {"gb18030", 248}, {"utf8mb4", 255}, }; const auto it = charset_map.find(charset); if (it != charset_map.cend()) return it->second; return -1; } bool ComplexMySQLTask::init_success() { if (uri_.scheme && strcasecmp(uri_.scheme, "mysql") == 0) is_ssl_ = false; else if (uri_.scheme && strcasecmp(uri_.scheme, "mysqls") == 0) is_ssl_ = true; else { this->state = WFT_STATE_TASK_ERROR; this->error = WFT_ERR_URI_SCHEME_INVALID; return false; } //todo mysql+unix username_.clear(); password_.clear(); db_.clear(); if (uri_.userinfo) { const char *colon = NULL; const char *pos = uri_.userinfo; while (*pos && *pos != ':') pos++; if (*pos == ':') colon = pos++; if (colon) { if (colon > uri_.userinfo) { username_.assign(uri_.userinfo, colon - uri_.userinfo); StringUtil::url_decode(username_); } if (*pos) { password_.assign(pos); StringUtil::url_decode(password_); } } else { username_.assign(uri_.userinfo); StringUtil::url_decode(username_); } } if (uri_.path && uri_.path[0] == '/' && uri_.path[1]) { db_.assign(uri_.path + 1); StringUtil::url_decode(db_); } std::string transaction; if (uri_.query) { auto query_kv = URIParser::split_query(uri_.query); for (auto& kv : query_kv) { if (strcasecmp(kv.first.c_str(), "transaction") == 0) transaction = std::move(kv.second); else if (strcasecmp(kv.first.c_str(), "character_set") == 0) { character_set_ = __mysql_get_character_set(kv.second); if (character_set_ < 0) { this->state = WFT_STATE_TASK_ERROR; this->error = WFT_ERR_MYSQL_INVALID_CHARACTER_SET; return false; } } else if (strcasecmp(kv.first.c_str(), "character_set_results") == 0) res_charset_ = std::move(kv.second); } } size_t info_len = 
username_.size() + password_.size() + db_.size() +
					  res_charset_.size() + 50;
	char *info = new char[info_len];

	// Every setting that distinguishes one connection from another goes into
	// the info string passed to set_info().
	snprintf(info, info_len, "%s|user:%s|pass:%s|db:%s|"
			 "charset:%d|rcharset:%s",
			 is_ssl_ ? "mysqls" : "mysql",
			 username_.c_str(), password_.c_str(), db_.c_str(),
			 character_set_, res_charset_.c_str());

	// The transport is TT_TCP even for "mysqls".
	this->WFComplexClientTask::set_transport_type(TT_TCP);

	if (!transaction.empty())
	{
		// A transaction pins both the address and the connection and adds
		// its name to the info string.
		this->set_fixed_addr(true);
		this->set_fixed_conn(true);
		this->WFComplexClientTask::set_info(info + ("|txn:" + transaction));
	}
	else
		this->WFComplexClientTask::set_info(info);

	delete []info;
	return true;
}

/* Called after each request/response round.
 * After an internal (non-user) round the internally allocated messages are
 * deleted, a failure recorded in state_ / error_ is promoted to the task
 * (with retry disabled), and false is returned so that the next round runs.
 * After the user round true is returned; for a fixed (transaction) connection
 * the target's state flag is cleared when the round failed or the connection
 * is not kept alive. */
bool ComplexMySQLTask::finish_once()
{
	if (!is_user_request_)
	{
		delete this->get_message_out();
		delete this->get_message_in();

		if (this->state == WFT_STATE_SUCCESS && state_ != WFT_STATE_SUCCESS)
		{
			this->state = state_;
			this->error = error_;
			this->disable_retry();
		}

		is_user_request_ = true;
		return false;
	}

	if (this->is_fixed_conn())
	{
		if (this->state != WFT_STATE_SUCCESS || this->keep_alive_timeo == 0)
		{
			if (this->target)
				((RouteManager::RouteTarget *)this->target)->state = 0;
		}
	}

	return true;
}

/**********Client Factory**********/

// mysql://user:password@host:port/db_name
// url = "mysql://admin:123456@192.168.1.101:3301/test"
// url = "mysql://127.0.0.1:3306"
/* Create a MySQL client task from a URL string. The parse result is passed
 * to task->init() unchecked. Fixed-connection (transaction) tasks get the
 * longer MYSQL_KEEPALIVE_TRANSACTION keep-alive. */
WFMySQLTask *WFTaskFactory::create_mysql_task(const std::string& url,
											  int retry_max,
											  mysql_callback_t callback)
{
	auto *task = new ComplexMySQLTask(retry_max, std::move(callback));
	ParsedURI uri;

	URIParser::parse(url, uri);
	task->init(std::move(uri));
	if (task->is_fixed_conn())
		task->set_keep_alive(MYSQL_KEEPALIVE_TRANSACTION);
	else
		task->set_keep_alive(MYSQL_KEEPALIVE_DEFAULT);

	return task;
}

/* Same as above, for an already parsed URI. */
WFMySQLTask *WFTaskFactory::create_mysql_task(const ParsedURI& uri,
											  int retry_max,
											  mysql_callback_t callback)
{
	auto *task = new ComplexMySQLTask(retry_max, std::move(callback));

	task->init(uri);
	if (task->is_fixed_conn())
		task->set_keep_alive(MYSQL_KEEPALIVE_TRANSACTION);
	else
		task->set_keep_alive(MYSQL_KEEPALIVE_DEFAULT);

	return task;
}

/**********Server**********/

/* Server-side MySQL task. The first round (seq 0) reads a MySQLAuthRequest
 * and always answers it with an OK packet; later rounds use the default
 * WFServerTask messages.
 * NOTE(review): the template arguments of WFServerTask and std::function
 * appear to have been stripped from this copy of the file. */
class WFMySQLServerTask : public WFServerTask
{
public:
	WFMySQLServerTask(CommService *service,
					  std::function& proc):
		WFServerTask(service, WFGlobal::get_scheduler(), proc)
	{}

protected:
	virtual SubTask *done();
	virtual CommMessageOut *message_out();
	virtual CommMessageIn *message_in();
};

// The seq-0 request was allocated in message_in(); release it here.
SubTask *WFMySQLServerTask::done()
{
	if (this->get_seq() == 0)
		delete this->get_message_in();

	return this->WFServerTask::done();
}

CommMessageOut *WFMySQLServerTask::message_out()
{
	long long seqid = this->get_seq();

	if (seqid == 0)
		this->resp.set_ok_packet();	// always success

	return this->WFServerTask::message_out();
}

CommMessageIn *WFMySQLServerTask::message_in()
{
	long long seqid = this->get_seq();

	if (seqid == 0)
		return new MySQLAuthRequest;

	return this->WFServerTask::message_in();
}

/**********Server Factory**********/

/* Create a server-side MySQL task bound to `service`, run by `process`. */
WFMySQLTask *WFServerTaskFactory::create_mysql_task(CommService *service,
													std::function& process)
{
	return new WFMySQLServerTask(service, process);
}
workflow-0.11.8/src/factory/RedisTaskImpl.cc000066400000000000000000000236211476003635400207470ustar00rootroot00000000000000/*
  Copyright (c) 2019 Sogou, Inc.

  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

  Unless required by applicable law or agreed to in writing, software
  distributed under the License is distributed on an "AS IS" BASIS,
  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  See the License for the specific language governing permissions and
  limitations under the License.
Authors: Wu Jiaxu (wujiaxu@sogou-inc.com)
           Li Yingxin (liyingxin@sogou-inc.com)
           Liu Kai (liukaidx@sogou-inc.com)
           Xie Han (xiehan@sogou-inc.com)
*/

// NOTE(review): the three bare #include directives lost their header names
// (and several templates lost their argument lists) in this copy of the file.
#include
#include
#include
#include "PackageWrapper.h"
#include "WFTaskError.h"
#include "WFTaskFactory.h"
#include "StringUtil.h"
#include "RedisTaskImpl.inl"

using namespace protocol;

#define REDIS_KEEPALIVE_DEFAULT		(60 * 1000)
// Upper bound on MOVED/ASK redirections followed by one task.
#define REDIS_REDIRECT_MAX			3

/**********Client**********/

/* Redis client task. On a new connection it may first send AUTH and/or
 * SELECT on its own (see message_out()), and it follows MOVED/ASK replies up
 * to REDIS_REDIRECT_MAX times (see need_redirect()). */
class ComplexRedisTask : public WFComplexClientTask
{
public:
	ComplexRedisTask(int retry_max, redis_callback_t&& callback):
		WFComplexClientTask(retry_max, std::move(callback)),
		db_num_(0),
		is_user_request_(true),
		redirect_count_(0)
	{}

protected:
	virtual bool check_request();
	virtual CommMessageOut *message_out();
	virtual CommMessageIn *message_in();
	virtual int keep_alive_timeout();
	virtual int first_timeout();
	virtual bool init_success();
	virtual bool finish_once();

protected:
	bool need_redirect();

	std::string username_;
	std::string password_;
	int db_num_;
	bool succ_;				// outcome of the last internal AUTH/SELECT
	bool is_user_request_;
	int redirect_count_;
};

/* Reject user requests whose command is AUTH, SELECT, RESET or ASKING
 * (case-insensitive) with WFT_ERR_REDIS_COMMAND_DISALLOWED. */
bool ComplexRedisTask::check_request()
{
	std::string command;

	if (this->req.get_command(command) &&
		(strcasecmp(command.c_str(), "AUTH") == 0 ||
		 strcasecmp(command.c_str(), "SELECT") == 0 ||
		 strcasecmp(command.c_str(), "RESET") == 0 ||
		 strcasecmp(command.c_str(), "ASKING") == 0))
	{
		this->state = WFT_STATE_TASK_ERROR;
		this->error = WFT_ERR_REDIS_COMMAND_DISALLOWED;
		return false;
	}

	return true;
}

/* Build the next outgoing message.
 * seq 0: AUTH when a username or password is set, otherwise SELECT when
 * db_num_ > 0. seq 1: SELECT when db_num_ > 0 and credentials are set (i.e.
 * seq 0 was the AUTH). In every other case the user's request is sent. */
CommMessageOut *ComplexRedisTask::message_out()
{
	long long seqid = this->get_seq();

	if (seqid <= 1)
	{
		if (seqid == 0 && (!password_.empty() || !username_.empty()))
		{
			auto *auth_req = new RedisRequest;

			if (!username_.empty())
				auth_req->set_request("AUTH", {username_, password_});
			else
				auth_req->set_request("AUTH", {password_});

			succ_ = false;
			is_user_request_ = false;
			return auth_req;
		}

		if (db_num_ > 0 &&
			(seqid == 0 || !password_.empty() || !username_.empty()))
		{
			auto *select_req = new RedisRequest;
			char buf[32];

			sprintf(buf, "%d", db_num_);
			select_req->set_request("SELECT", {buf});

			succ_ = false;
			is_user_request_ = false;
			return select_req;
		}
	}

	return this->WFComplexClientTask::message_out();
}

/* Propagate the request's "asking" flag to the response for user requests
 * (cleared for internal ones), then use the default incoming message. */
CommMessageIn *ComplexRedisTask::message_in()
{
	RedisRequest *req = this->get_req();
	RedisResponse *resp = this->get_resp();

	if (is_user_request_)
		resp->set_asking(req->is_asking());
	else
		resp->set_asking(false);

	return this->WFComplexClientTask::message_in();
}

/* For an internal AUTH/SELECT round, record in succ_ whether the reply parsed
 * and is not an error reply; the connection is kept only on success. */
int ComplexRedisTask::keep_alive_timeout()
{
	if (this->is_user_request_)
		return this->keep_alive_timeo;

	RedisResponse *resp = this->get_resp();

	succ_ = (resp->parse_success() &&
			 resp->result_ptr()->type != REDIS_REPLY_TYPE_ERROR);

	return succ_ ? REDIS_KEEPALIVE_DEFAULT : 0;
}

int ComplexRedisTask::first_timeout()
{
	return is_user_request_ ? this->watch_timeo : 0;
}

/* Validate the URI and derive connection parameters from it.
 * Scheme "redis" -> TT_TCP, "rediss" -> TT_TCP_SSL (case-insensitive), else
 * WFT_ERR_URI_SCHEME_INVALID. Userinfo "user[:password]" is URL-decoded and
 * the path "/N" gives the database number via atoi(). */
bool ComplexRedisTask::init_success()
{
	enum TransportType type;

	if (uri_.scheme && strcasecmp(uri_.scheme, "redis") == 0)
		type = TT_TCP;
	else if (uri_.scheme && strcasecmp(uri_.scheme, "rediss") == 0)
		type = TT_TCP_SSL;
	else
	{
		this->state = WFT_STATE_TASK_ERROR;
		this->error = WFT_ERR_URI_SCHEME_INVALID;
		return false;
	}

	//todo redis+unix
	//https://stackoverflow.com/questions/26964595/whats-the-correct-way-to-use-a-unix-domain-socket-in-requests-framework
	//https://stackoverflow.com/questions/27037990/connecting-to-postgres-via-database-url-and-unix-socket-in-rails

	if (uri_.userinfo)
	{
		char *p = strchr(uri_.userinfo, ':');
		if (p)
		{
			username_.assign(uri_.userinfo, p);
			password_.assign(p + 1);
			StringUtil::url_decode(username_);
			StringUtil::url_decode(password_);
		}
		else
		{
			username_.assign(uri_.userinfo);
			StringUtil::url_decode(username_);
		}
	}

	if (uri_.path && uri_.path[0] == '/' && uri_.path[1])
		db_num_ = atoi(uri_.path + 1);

	// 64 spare bytes cover the fixed text and the decimal database number.
	size_t info_len = username_.size() + password_.size() + 32 + 32;
	char *info = new char[info_len];

	sprintf(info, "redis|user:%s|pass:%s|db:%d", username_.c_str(),
			password_.c_str(), db_num_);
	this->WFComplexClientTask::set_transport_type(type);
	this->WFComplexClientTask::set_info(info);

	delete []info;
	return true;
}

/* Inspect an error reply and decide whether the request must be re-sent to
 * another node.
 * - "NOAUTH ..." fails the task with WFT_ERR_REDIS_ACCESS_DENIED.
 * - "ASK ..." / "MOVED ..." with exactly three space-separated fields swaps
 *   the new host/port into uri_ and returns true, unless the redirect limit
 *   was already reached.
 * Returns false in every other case. */
bool ComplexRedisTask::need_redirect()
{
	RedisRequest *client_req = this->get_req();
	RedisResponse *client_resp = this->get_resp();
	redis_reply_t *reply = client_resp->result_ptr();

	if (reply->type == REDIS_REPLY_TYPE_ERROR)
	{
		if (reply->str == NULL)
			return false;

		if (strncasecmp(reply->str, "NOAUTH ", 7) == 0)
		{
			this->state = WFT_STATE_TASK_ERROR;
			this->error = WFT_ERR_REDIS_ACCESS_DENIED;
			return false;
		}

		bool asking = false;
		if (strncasecmp(reply->str, "ASK ", 4) == 0)
			asking = true;
		else if (strncasecmp(reply->str, "MOVED ", 6) != 0)
			return false;

		if (redirect_count_ >= REDIS_REDIRECT_MAX)
			return false;

		std::string err_str(reply->str, reply->len);
		auto split_result = StringUtil::split_filter_empty(err_str, ' ');
		if (split_result.size() == 3)
		{
			client_req->set_asking(asking);
			// format: COMMAND SLOT HOSTPORT
			// example: MOVED/ASK 123 127.0.0.1:6379
			std::string& hostport = split_result[2];
			redirect_count_++;

			ParsedURI uri;
			std::string url;
			url.append(uri_.scheme);
			url.append("://");
			url.append(hostport);

			URIParser::parse(url, uri);
			// Only host, port and the parse status are taken over; the
			// remaining fields of uri_ are left as they are.
			std::swap(uri.host, uri_.host);
			std::swap(uri.port, uri_.port);
			std::swap(uri.state, uri_.state);
			std::swap(uri.error, uri_.error);

			return true;
		}
	}

	return false;
}

/* Called after each request/response round.
 * After an internal AUTH/SELECT round: release the request, and either clear
 * the response (success) or fail the task with WFT_ERR_REDIS_ACCESS_DENIED;
 * false is returned so that the next round runs. After the user round: follow
 * a redirect if need_redirect() asks for one, and disable retry when
 * need_redirect() turned the task into an error. */
bool ComplexRedisTask::finish_once()
{
	if (!is_user_request_)
	{
		is_user_request_ = true;
		delete this->get_message_out();

		if (this->state == WFT_STATE_SUCCESS)
		{
			if (succ_)
				this->clear_resp();
			else
			{
				this->disable_retry();
				this->state = WFT_STATE_TASK_ERROR;
				this->error = WFT_ERR_REDIS_ACCESS_DENIED;
			}
		}

		return false;
	}

	if (this->state == WFT_STATE_SUCCESS)
	{
		if (need_redirect())
			this->set_redirect(uri_);
		else if (this->state != WFT_STATE_SUCCESS)
			this->disable_retry();
	}

	return true;
}

/****** Redis Subscribe ******/

/* Subscribe-style Redis task: after the initial exchange it keeps receiving
 * replies through SubscribeWrapper and hands each one to extract_. */
class ComplexRedisSubscribeTask : public ComplexRedisTask
{
public:
	/* Send extra data on the established connection.
	 * Fails with ENOENT once the task has finished and with EAGAIN before
	 * the first reply has been seen. */
	virtual int push(const void *buf, size_t size)
	{
		if (finished_)
		{
			errno = ENOENT;
			return -1;
		}

		if (!watching_)
		{
			errno = EAGAIN;
			return -1;
		}

		return this->scheduler->push(buf, size, this);
	}

protected:
	// User rounds receive through the wrapper so that further replies keep
	// arriving on the same round; internal rounds use the plain response.
	virtual CommMessageIn *message_in()
	{
		if (!is_user_request_)
			return this->ComplexRedisTask::message_in();

		return &wrapper_;
	}

	virtual int first_timeout()
	{
		return watching_ ? this->watch_timeo : 0;
	}

protected:
	class SubscribeWrapper : public PackageWrapper
	{
	protected:
		virtual ProtocolMessage *next_in(ProtocolMessage *message);

	protected:
		ComplexRedisSubscribeTask *task_;

	public:
		SubscribeWrapper(ComplexRedisSubscribeTask *task) :
			PackageWrapper(task->get_resp())
		{
			task_ = task;
		}
	};

protected:
	SubscribeWrapper wrapper_;
	bool watching_;		// set once the first array reply has been received
	bool finished_;		// set when no further replies are expected
	std::function extract_;

public:
	// Retry is fixed to 0 for subscribe tasks.
	ComplexRedisSubscribeTask(std::function&& extract,
							  redis_callback_t&& callback) :
		ComplexRedisTask(0, std::move(callback)),
		wrapper_(this),
		extract_(std::move(extract))
	{
		watching_ = false;
		finished_ = false;
	}
};

/* Called for each complete reply. A non-array reply ends the task. An array
 * reply of three elements whose third element is the integer 0 marks the
 * task finished (presumably the remaining-subscription count reaching zero
 * -- confirm). Each array reply is passed to extract_, after which the
 * response object is reset for the next reply. Returns NULL when no further
 * reply is expected. */
ProtocolMessage *
ComplexRedisSubscribeTask::SubscribeWrapper::next_in(ProtocolMessage *message)
{
	redis_reply_t *reply = task_->resp.result_ptr();

	if (reply->type != REDIS_REPLY_TYPE_ARRAY)
	{
		task_->finished_ = true;
		return NULL;
	}

	if (reply->elements == 3 &&
		reply->element[2]->type == REDIS_REPLY_TYPE_INTEGER &&
		reply->element[2]->integer == 0)
	{
		task_->finished_ = true;
	}

	task_->watching_ = true;
	task_->extract_(task_);

	// Replace the response with a fresh one, moving only the
	// ProtocolMessage base part of the old response into it.
	RedisResponse resp;
	*(protocol::ProtocolMessage *)&resp = std::move(task_->resp);
	task_->resp = std::move(resp);

	return task_->finished_ ? NULL : &task_->resp;
}

/**********Factory**********/

// redis://:password@host:port/db_num
// url = "redis://:admin@192.168.1.101:6001/3"
// url = "redis://127.0.0.1:6379"
/* Create a Redis client task from a URL string (parse result unchecked). */
WFRedisTask *WFTaskFactory::create_redis_task(const std::string& url,
											  int retry_max,
											  redis_callback_t callback)
{
	auto *task = new ComplexRedisTask(retry_max, std::move(callback));
	ParsedURI uri;

	URIParser::parse(url, uri);
	task->init(std::move(uri));
	task->set_keep_alive(REDIS_KEEPALIVE_DEFAULT);
	return task;
}

/* Create a Redis client task from an already parsed URI. */
WFRedisTask *WFTaskFactory::create_redis_task(const ParsedURI& uri,
											  int retry_max,
											  redis_callback_t callback)
{
	auto *task = new ComplexRedisTask(retry_max, std::move(callback));

	task->init(uri);
	task->set_keep_alive(REDIS_KEEPALIVE_DEFAULT);
	return task;
}

/* Create a subscribe task from a URL string; no keep-alive is set here. */
WFRedisTask *
__WFRedisTaskFactory::create_subscribe_task(const std::string& url,
											extract_t extract,
											redis_callback_t callback)
{
	auto *task = new ComplexRedisSubscribeTask(std::move(extract),
											   std::move(callback));
	ParsedURI uri;

	URIParser::parse(url, uri);
	task->init(std::move(uri));
	return task;
}

/* Create a subscribe task from an already parsed URI. */
WFRedisTask *
__WFRedisTaskFactory::create_subscribe_task(const ParsedURI& uri,
											extract_t extract,
											redis_callback_t callback)
{
	auto *task = new ComplexRedisSubscribeTask(std::move(extract),
											   std::move(callback));

	task->init(uri);
	return task;
}
workflow-0.11.8/src/factory/RedisTaskImpl.inl000066400000000000000000000020751476003635400211440ustar00rootroot00000000000000/*
  Copyright (c) 2024 Sogou, Inc.

  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

  Unless required by applicable law or agreed to in writing, software
  distributed under the License is distributed on an "AS IS" BASIS,
  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  See the License for the specific language governing permissions and
  limitations under the License.
Authors: Xie Han (xiehan@sogou-inc.com)
*/

#include "WFTaskFactory.h"

// Internal, for WFRedisSubscribeTask only.
/* Factory for the subscribe variant of the Redis task. Both overloads take
 * the extract callback and the completion callback; they differ only in how
 * the server address is given (URL string or parsed URI).
 * NOTE(review): the template argument list of std::function in extract_t
 * appears to have been stripped from this copy of the file. */
class __WFRedisTaskFactory
{
private:
	using extract_t = std::function;

public:
	static WFRedisTask *create_subscribe_task(const std::string& url,
											  extract_t extract,
											  redis_callback_t callback);

	static WFRedisTask *create_subscribe_task(const ParsedURI& uri,
											  extract_t extract,
											  redis_callback_t callback);
};
workflow-0.11.8/src/factory/WFAlgoTaskFactory.h000066400000000000000000000140151476003635400213650ustar00rootroot00000000000000/*
  Copyright (c) 2019 Sogou, Inc.

  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

  Unless required by applicable law or agreed to in writing, software
  distributed under the License is distributed on an "AS IS" BASIS,
  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  See the License for the specific language governing permissions and
  limitations under the License.
Author: Xie Han (xiehan@sogou-inc.com) */ #ifndef _WFALGOTASKFACTORY_H_ #define _WFALGOTASKFACTORY_H_ #include #include #include #include "WFTask.h" namespace algorithm { template struct SortInput { T *first; T *last; }; template struct SortOutput { T *first; T *last; }; template struct MergeInput { T *first1; T *last1; T *first2; T *last2; T *d_first; }; template struct MergeOutput { T *first; T *last; }; template struct ShuffleInput { T *first; T *last; }; template struct ShuffleOutput { T *first; T *last; }; template struct RemoveInput { T *first; T *last; T value; }; template struct RemoveOutput { T *first; T *last; }; template struct UniqueInput { T *first; T *last; }; template struct UniqueOutput { T *first; T *last; }; template struct ReverseInput { T *first; T *last; }; template struct ReverseOutput { T *first; T *last; }; template struct RotateInput { T *first; T *middle; T *last; }; template struct RotateOutput { T *first; T *last; }; template using ReduceInput = std::vector>; template using ReduceOutput = std::vector>; } /* namespace algorithm */ template using WFSortTask = WFThreadTask, algorithm::SortOutput>; template using sort_callback_t = std::function *)>; template using WFMergeTask = WFThreadTask, algorithm::MergeOutput>; template using merge_callback_t = std::function *)>; template using WFShuffleTask = WFThreadTask, algorithm::ShuffleOutput>; template using shuffle_callback_t = std::function *)>; template using WFRemoveTask = WFThreadTask, algorithm::RemoveOutput>; template using remove_callback_t = std::function *)>; template using WFUniqueTask = WFThreadTask, algorithm::UniqueOutput>; template using unique_callback_t = std::function *)>; template using WFReverseTask = WFThreadTask, algorithm::ReverseOutput>; template using reverse_callback_t = std::function *)>; template using WFRotateTask = WFThreadTask, algorithm::RotateOutput>; template using rotate_callback_t = std::function *)>; class WFAlgoTaskFactory { public: template> static 
WFSortTask *create_sort_task(const std::string& queue_name, T *first, T *last, CB callback); template> static WFSortTask *create_sort_task(const std::string& queue_name, T *first, T *last, CMP compare, CB callback); template> static WFSortTask *create_psort_task(const std::string& queue_name, T *first, T *last, CB callback); template> static WFSortTask *create_psort_task(const std::string& queue_name, T *first, T *last, CMP compare, CB callback); template> static WFMergeTask *create_merge_task(const std::string& queue_name, T *first1, T *last1, T *first2, T *last2, T *d_first, CB callback); template> static WFMergeTask *create_merge_task(const std::string& queue_name, T *first1, T *last1, T *first2, T *last2, T *d_first, CMP compare, CB callback); template> static WFShuffleTask *create_shuffle_task(const std::string& queue_name, T *first, T *last, CB callback); template> static WFShuffleTask *create_shuffle_task(const std::string& queue_name, T *first, T *last, URBG generator, CB callback); template> static WFRemoveTask *create_remove_task(const std::string& queue_name, T *first, T *last, T value, CB callback); template> static WFUniqueTask *create_unique_task(const std::string& queue_name, T *first, T *last, CB callback); template> static WFReverseTask *create_reverse_task(const std::string& queue_name, T *first, T *last, CB callback); template> static WFRotateTask *create_rotate_task(const std::string& queue_name, T *first, T *middle, T *last, CB callback); }; #include "WFAlgoTaskFactory.inl" #endif workflow-0.11.8/src/factory/WFAlgoTaskFactory.inl000066400000000000000000000403461476003635400217260ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Xie Han (xiehan@sogou-inc.com) */ #include #include #include #include #include #include #include "Workflow.h" #include "WFGlobal.h" /********** Classes without CMP **********/ template class __WFSortTask : public WFSortTask { protected: virtual void execute() { std::sort(this->input.first, this->input.last); this->output.first = this->input.first; this->output.last = this->input.last; } public: __WFSortTask(ExecQueue *queue, Executor *executor, T *first, T *last, sort_callback_t&& cb) : WFSortTask(queue, executor, std::move(cb)) { this->input.first = first; this->input.last = last; this->output.first = NULL; this->output.last = NULL; } }; template class __WFMergeTask : public WFMergeTask { protected: virtual void execute(); public: __WFMergeTask(ExecQueue *queue, Executor *executor, T *first1, T *last1, T *first2, T *last2, T *d_first, merge_callback_t&& cb) : WFMergeTask(queue, executor, std::move(cb)) { this->input.first1 = first1; this->input.last1 = last1; this->input.first2 = first2; this->input.last2 = last2; this->input.d_first = d_first; this->output.first = NULL; this->output.last = NULL; } }; template void __WFMergeTask::execute() { auto *input = &this->input; auto *output = &this->output; if (input->first1 == input->d_first && input->last1 == input->first2) { std::inplace_merge(input->first1, input->first2, input->last2); output->last = input->last2; } else if (input->first2 == input->d_first && input->last2 == input->first1) { std::inplace_merge(input->first2, input->first1, input->last1); output->last = input->last1; } else { output->last = 
std::merge(input->first1, input->last1, input->first2, input->last2, input->d_first); } output->first = input->d_first; } template class __WFParSortTask : public __WFSortTask { public: virtual void dispatch(); protected: virtual SubTask *done() { if (this->flag) return series_of(this)->pop(); return this->WFSortTask::done(); } virtual void execute(); protected: int depth; int flag; public: __WFParSortTask(ExecQueue *queue, Executor *executor, T *first, T *last, int depth, sort_callback_t&& cb) : __WFSortTask(queue, executor, first, last, std::move(cb)) { this->depth = depth; this->flag = 0; } }; template void __WFParSortTask::dispatch() { size_t n = this->input.last - this->input.first; if (!this->flag && this->depth < 7 && n >= 32) { SeriesWork *series = series_of(this); T *middle = this->input.first + n / 2; auto *task1 = new __WFParSortTask(this->queue, this->executor, this->input.first, middle, this->depth + 1, nullptr); auto *task2 = new __WFParSortTask(this->queue, this->executor, middle, this->input.last, this->depth + 1, nullptr); SeriesWork *sub_series[2] = { Workflow::create_series_work(task1, nullptr), Workflow::create_series_work(task2, nullptr) }; ParallelWork *parallel = Workflow::create_parallel_work(sub_series, 2, nullptr); series->push_front(this); series->push_front(parallel); this->flag = 1; this->subtask_done(); } else this->__WFSortTask::dispatch(); } template void __WFParSortTask::execute() { if (this->flag) { size_t n = this->input.last - this->input.first; T *middle = this->input.first + n / 2; std::inplace_merge(this->input.first, middle, this->input.last); this->output.first = this->input.first; this->output.last = this->input.last; this->flag = 0; } else this->__WFSortTask::execute(); } /********** Classes with CMP **********/ template class __WFSortTaskCmp : public __WFSortTask { protected: virtual void execute() { std::sort(this->input.first, this->input.last, std::move(this->compare)); this->output.first = this->input.first; 
this->output.last = this->input.last; } protected: CMP compare; public: __WFSortTaskCmp(ExecQueue *queue, Executor *executor, T *first, T *last, CMP&& cmp, sort_callback_t&& cb) : __WFSortTask(queue, executor, first, last, std::move(cb)), compare(std::move(cmp)) { } }; template class __WFMergeTaskCmp : public __WFMergeTask { protected: virtual void execute(); protected: CMP compare; public: __WFMergeTaskCmp(ExecQueue *queue, Executor *executor, T *first1, T *last1, T *first2, T *last2, T *d_first, CMP&& cmp, merge_callback_t&& cb) : __WFMergeTask(queue, executor, first1, last1, first2, last2, d_first, std::move(cb)), compare(std::move(cmp)) { } }; template void __WFMergeTaskCmp::execute() { auto *input = &this->input; auto *output = &this->output; if (input->first1 == input->d_first && input->last1 == input->first2) { std::inplace_merge(input->first1, input->first2, input->last2, std::move(this->compare)); output->last = input->last2; } else if (input->first2 == input->d_first && input->last2 == input->first1) { std::inplace_merge(input->first2, input->first1, input->last1, std::move(this->compare)); output->last = input->last1; } else { output->last = std::merge(input->first1, input->last1, input->first2, input->last2, input->d_first, std::move(this->compare)); } output->first = input->d_first; } template class __WFParSortTaskCmp : public __WFSortTaskCmp { public: virtual void dispatch(); protected: virtual SubTask *done() { if (this->flag) return series_of(this)->pop(); return this->WFSortTask::done(); } virtual void execute(); protected: int depth; int flag; public: __WFParSortTaskCmp(ExecQueue *queue, Executor *executor, T *first, T *last, CMP cmp, int depth, sort_callback_t&& cb) : __WFSortTaskCmp(queue, executor, first, last, std::move(cmp), std::move(cb)) { this->depth = depth; this->flag = 0; } }; template void __WFParSortTaskCmp::dispatch() { size_t n = this->input.last - this->input.first; if (!this->flag && this->depth < 7 && n >= 32) { SeriesWork 
*series = series_of(this); T *middle = this->input.first + n / 2; auto *task1 = new __WFParSortTaskCmp(this->queue, this->executor, this->input.first, middle, this->compare, this->depth + 1, nullptr); auto *task2 = new __WFParSortTaskCmp(this->queue, this->executor, middle, this->input.last, this->compare, this->depth + 1, nullptr); SeriesWork *sub_series[2] = { Workflow::create_series_work(task1, nullptr), Workflow::create_series_work(task2, nullptr) }; ParallelWork *parallel = Workflow::create_parallel_work(sub_series, 2, nullptr); series->push_front(this); series->push_front(parallel); this->flag = 1; this->subtask_done(); } else this->__WFSortTaskCmp::dispatch(); } template void __WFParSortTaskCmp::execute() { if (this->flag) { size_t n = this->input.last - this->input.first; T *middle = this->input.first + n / 2; std::inplace_merge(this->input.first, middle, this->input.last, std::move(this->compare)); this->output.first = this->input.first; this->output.last = this->input.last; this->flag = 0; } else this->__WFSortTaskCmp::execute(); } /********** Factory functions without CMP **********/ template WFSortTask *WFAlgoTaskFactory::create_sort_task(const std::string& name, T *first, T *last, CB callback) { return new __WFSortTask(WFGlobal::get_exec_queue(name), WFGlobal::get_compute_executor(), first, last, std::move(callback)); } template WFMergeTask *WFAlgoTaskFactory::create_merge_task(const std::string& name, T *first1, T *last1, T *first2, T *last2, T *d_first, CB callback) { return new __WFMergeTask(WFGlobal::get_exec_queue(name), WFGlobal::get_compute_executor(), first1, last1, first2, last2, d_first, std::move(callback)); } template WFSortTask *WFAlgoTaskFactory::create_psort_task(const std::string& name, T *first, T *last, CB callback) { return new __WFParSortTask(WFGlobal::get_exec_queue(name), WFGlobal::get_compute_executor(), first, last, 0, std::move(callback)); } /********** Factory functions with CMP **********/ template WFSortTask 
*WFAlgoTaskFactory::create_sort_task(const std::string& name, T *first, T *last, CMP compare, CB callback) { return new __WFSortTaskCmp(WFGlobal::get_exec_queue(name), WFGlobal::get_compute_executor(), first, last, std::move(compare), std::move(callback)); } template WFMergeTask *WFAlgoTaskFactory::create_merge_task(const std::string& name, T *first1, T *last1, T *first2, T *last2, T *d_first, CMP compare, CB callback) { return new __WFMergeTaskCmp(WFGlobal::get_exec_queue(name), WFGlobal::get_compute_executor(), first1, last1, first2, last2, d_first, std::move(compare), std::move(callback)); } template WFSortTask *WFAlgoTaskFactory::create_psort_task(const std::string& name, T *first, T *last, CMP compare, CB callback) { return new __WFParSortTaskCmp(WFGlobal::get_exec_queue(name), WFGlobal::get_compute_executor(), first, last, std::move(compare), 0, std::move(callback)); } /****************** Shuffle ******************/ template class __WFShuffleTask : public WFShuffleTask { protected: virtual void execute() { std::shuffle(this->input.first, this->input.last, std::mt19937_64(rand())); this->output.first = this->input.first; this->output.last = this->input.last; } public: __WFShuffleTask(ExecQueue *queue, Executor *executor, T *first, T *last, shuffle_callback_t&& cb) : WFShuffleTask(queue, executor, std::move(cb)) { this->input.first = first; this->input.last = last; this->output.first = NULL; this->output.last = NULL; } }; template class __WFShuffleTaskGen : public __WFShuffleTask { protected: virtual void execute() { std::shuffle(this->input.first, this->input.last, std::move(this->generator)); this->output.first = this->input.first; this->output.last = this->input.last; } protected: URBG generator; public: __WFShuffleTaskGen(ExecQueue *queue, Executor *executor, T *first, T *last, URBG&& gen, shuffle_callback_t&& cb) : __WFShuffleTask(queue, executor, std::move(cb)), generator(std::move(gen)) { } }; template WFShuffleTask 
*WFAlgoTaskFactory::create_shuffle_task(const std::string& name, T *first, T *last, CB callback) { return new __WFShuffleTask(WFGlobal::get_exec_queue(name), WFGlobal::get_compute_executor(), first, last, std::move(callback)); } template WFShuffleTask *WFAlgoTaskFactory::create_shuffle_task(const std::string& name, T *first, T *last, URBG generator, CB callback) { return new __WFShuffleTaskGen(WFGlobal::get_exec_queue(name), WFGlobal::get_compute_executor(), first, last, std::move(generator), std::move(callback)); } /****************** Remove ******************/ template class __WFRemoveTask : public WFRemoveTask { protected: virtual void execute() { this->output.last = std::remove(this->input.first, this->input.last, this->input.value); this->output.first = this->input.first; } public: __WFRemoveTask(ExecQueue *queue, Executor *executor, T *first, T *last, T&& value, remove_callback_t&& cb) : WFRemoveTask(queue, executor, std::move(cb)) { this->input.first = first; this->input.last = last; this->input.value = std::move(value); this->output.first = NULL; this->output.last = NULL; } }; template WFRemoveTask *WFAlgoTaskFactory::create_remove_task(const std::string& name, T *first, T *last, T value, CB callback) { return new __WFRemoveTask(WFGlobal::get_exec_queue(name), WFGlobal::get_compute_executor(), first, last, std::move(value), std::move(callback)); } /****************** Unique ******************/ template class __WFUniqueTask : public WFUniqueTask { protected: virtual void execute() { this->output.last = std::unique(this->input.first, this->input.last); this->output.first = this->input.first; } public: __WFUniqueTask(ExecQueue *queue, Executor *executor, T *first, T *last, unique_callback_t&& cb) : WFUniqueTask(queue, executor, std::move(cb)) { this->input.first = first; this->input.last = last; this->output.first = NULL; this->output.last = NULL; } }; template WFUniqueTask *WFAlgoTaskFactory::create_unique_task(const std::string& name, T *first, T *last, CB 
callback) { return new __WFUniqueTask(WFGlobal::get_exec_queue(name), WFGlobal::get_compute_executor(), first, last, std::move(callback)); } /****************** Reverse ******************/ template class __WFReverseTask : public WFReverseTask { protected: virtual void execute() { std::reverse(this->input.first, this->input.last); this->output.first = this->input.first; this->output.last = this->input.last; } public: __WFReverseTask(ExecQueue *queue, Executor *executor, T *first, T *last, reverse_callback_t&& cb) : WFReverseTask(queue, executor, std::move(cb)) { this->input.first = first; this->input.last = last; this->output.first = NULL; this->output.last = NULL; } }; template WFReverseTask *WFAlgoTaskFactory::create_reverse_task(const std::string& name, T *first, T *last, CB callback) { return new __WFReverseTask(WFGlobal::get_exec_queue(name), WFGlobal::get_compute_executor(), first, last, std::move(callback)); } /****************** Rotate ******************/ template class __WFRotateTask : public WFRotateTask { protected: virtual void execute() { std::rotate(this->input.first, this->input.middle, this->input.last); this->output.first = this->input.first; this->output.last = this->input.last; } public: __WFRotateTask(ExecQueue *queue, Executor *executor, T *first, T* middle, T *last, rotate_callback_t&& cb) : WFRotateTask(queue, executor, std::move(cb)) { this->input.first = first; this->input.middle = middle; this->input.last = last; this->output.first = NULL; this->output.last = NULL; } }; template WFRotateTask *WFAlgoTaskFactory::create_rotate_task(const std::string& name, T *first, T *middle, T *last, CB callback) { return new __WFRotateTask(WFGlobal::get_exec_queue(name), WFGlobal::get_compute_executor(), first, middle, last, std::move(callback)); } workflow-0.11.8/src/factory/WFConnection.h000066400000000000000000000033731476003635400204340ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Xie Han (xiehan@sogou-inc.com) */ #ifndef _WFCONNECTION_H_ #define _WFCONNECTION_H_ #include #include #include #include "Communicator.h" class WFConnection : public CommConnection { public: void *get_context() const { return this->context; } void set_context(void *context, std::function deleter) { this->context = context; this->deleter = std::move(deleter); } void set_context(void *context) { this->context = context; } void *test_set_context(void *test_context, void *new_context, std::function deleter) { if (this->context.compare_exchange_strong(test_context, new_context)) { this->deleter = std::move(deleter); return new_context; } return test_context; } void *test_set_context(void *test_context, void *new_context) { if (this->context.compare_exchange_strong(test_context, new_context)) return new_context; return test_context; } private: std::atomic context; std::function deleter; public: WFConnection() : context(NULL) { } protected: virtual ~WFConnection() { if (this->deleter) this->deleter(this->context); } }; #endif workflow-0.11.8/src/factory/WFGraphTask.cc000066400000000000000000000041231476003635400203510ustar00rootroot00000000000000/* Copyright (c) 2020 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Xie Han (xiehan@sogou-inc.com) */ #include #include "Workflow.h" #include "WFGraphTask.h" SubTask *WFGraphNode::done() { SeriesWork *series = series_of(this); if (!this->user_data) { this->value = 1; this->user_data = (void *)1; } else delete this; return series->pop(); } WFGraphNode::~WFGraphNode() { if (this->user_data) { if (series_of(this)->is_canceled()) { for (WFGraphNode *node : this->successors) series_of(node)->SeriesWork::cancel(); } for (WFGraphNode *node : this->successors) node->WFCounterTask::count(); } } WFGraphNode& WFGraphTask::create_graph_node(SubTask *task) { WFGraphNode *node = new WFGraphNode; SeriesWork *series = Workflow::create_series_work(node, node, nullptr); series->push_back(task); this->parallel->add_series(series); return *node; } void WFGraphTask::dispatch() { SeriesWork *series = series_of(this); if (this->parallel) { series->push_front(this); series->push_front(this->parallel); this->parallel = NULL; } else this->state = WFT_STATE_SUCCESS; this->subtask_done(); } SubTask *WFGraphTask::done() { SeriesWork *series = series_of(this); if (this->state == WFT_STATE_SUCCESS) { if (this->callback) this->callback(this); delete this; } return series->pop(); } WFGraphTask::~WFGraphTask() { SeriesWork *series; size_t i; if (this->parallel) { for (i = 0; i < this->parallel->size(); i++) { series = this->parallel->series_at(i); series->unset_last_task(); } this->parallel->dismiss(); } } workflow-0.11.8/src/factory/WFGraphTask.h000066400000000000000000000041411476003635400202130ustar00rootroot00000000000000/* Copyright (c) 2020 Sogou, Inc. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Xie Han (xiehan@sogou-inc.com) */ #ifndef _WFGRAPHTASK_H_ #define _WFGRAPHTASK_H_ #include #include #include #include "Workflow.h" #include "WFTask.h" class WFGraphNode : protected WFCounterTask { public: void precede(WFGraphNode& node) { node.value++; this->successors.push_back(&node); } void succeed(WFGraphNode& node) { node.precede(*this); } protected: virtual SubTask *done(); protected: std::vector successors; protected: WFGraphNode() : WFCounterTask(0, nullptr) { } virtual ~WFGraphNode(); friend class WFGraphTask; }; static inline WFGraphNode& operator --(WFGraphNode& node, int) { return node; } static inline WFGraphNode& operator > (WFGraphNode& prec, WFGraphNode& succ) { prec.precede(succ); return succ; } static inline WFGraphNode& operator < (WFGraphNode& succ, WFGraphNode& prec) { succ.succeed(prec); return prec; } static inline WFGraphNode& operator --(WFGraphNode& node) { return node; } class WFGraphTask : public WFGenericTask { public: WFGraphNode& create_graph_node(SubTask *task); public: void set_callback(std::function cb) { this->callback = std::move(cb); } protected: virtual void dispatch(); virtual SubTask *done(); protected: ParallelWork *parallel; std::function callback; public: WFGraphTask(std::function&& cb) : callback(std::move(cb)) { this->parallel = Workflow::create_parallel_work(nullptr); } protected: virtual ~WFGraphTask(); }; #endif 
workflow-0.11.8/src/factory/WFMessageQueue.cc000066400000000000000000000040461476003635400210620ustar00rootroot00000000000000/* Copyright (c) 2022 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Authors: Xie Han (xiehan@sogou-inc.com) */ #include "list.h" #include "WFTask.h" #include "WFMessageQueue.h" class __MQConditional : public WFConditional { public: struct list_head list; struct WFMessageQueue::Data *data; public: virtual void dispatch(); virtual void signal(void *msg) { } public: __MQConditional(SubTask *task, void **msgbuf, struct WFMessageQueue::Data *data) : WFConditional(task, msgbuf) { this->data = data; } __MQConditional(SubTask *task, struct WFMessageQueue::Data *data) : WFConditional(task) { this->data = data; } }; void __MQConditional::dispatch() { struct WFMessageQueue::Data *data = this->data; data->mutex.lock(); if (!list_empty(&data->msg_list)) this->WFConditional::signal(data->pop()); else list_add_tail(&this->list, &data->wait_list); data->mutex.unlock(); this->WFConditional::dispatch(); } WFConditional *WFMessageQueue::get(SubTask *task, void **msgbuf) { return new __MQConditional(task, msgbuf, &this->data); } WFConditional *WFMessageQueue::get(SubTask *task) { return new __MQConditional(task, &this->data); } void WFMessageQueue::post(void *msg) { struct WFMessageQueue::Data *data = &this->data; WFConditional *cond; data->mutex.lock(); if (!list_empty(&data->wait_list)) { cond = list_entry(data->wait_list.next, __MQConditional, list); list_del(data->wait_list.next); } 
else { cond = NULL; this->push(msg); } data->mutex.unlock(); if (cond) cond->WFConditional::signal(msg); } workflow-0.11.8/src/factory/WFMessageQueue.h000066400000000000000000000034131476003635400207210ustar00rootroot00000000000000/* Copyright (c) 2022 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Authors: Xie Han (xiehan@sogou-inc.com) */ #ifndef _WFMESSAGEQUEUE_H_ #define _WFMESSAGEQUEUE_H_ #include #include "list.h" #include "WFTask.h" class WFMessageQueue { public: WFConditional *get(SubTask *task, void **msgbuf); WFConditional *get(SubTask *task); void post(void *msg); public: struct Data { void *pop() { return this->queue->pop(); } void push(void *msg) { this->queue->push(msg); } struct list_head msg_list; struct list_head wait_list; std::mutex mutex; WFMessageQueue *queue; }; protected: struct MessageEntry { struct list_head list; void *msg; }; protected: virtual void *pop() { struct MessageEntry *entry; void *msg; entry = list_entry(this->data.msg_list.next, struct MessageEntry, list); list_del(&entry->list); msg = entry->msg; delete entry; return msg; } virtual void push(void *msg) { struct MessageEntry *entry = new struct MessageEntry; entry->msg = msg; list_add_tail(&entry->list, &this->data.msg_list); } protected: struct Data data; public: WFMessageQueue() { INIT_LIST_HEAD(&this->data.msg_list); INIT_LIST_HEAD(&this->data.wait_list); this->data.queue = this; } virtual ~WFMessageQueue() { } }; #endif 
workflow-0.11.8/src/factory/WFOperator.h000066400000000000000000000214031476003635400201220ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) */ #ifndef _WFOPERATOR_H_ #define _WFOPERATOR_H_ #include "Workflow.h" /** * @file WFOperator.h * @brief Workflow Series/Parallel/Task operator */ /** * @brief S=S>P * @note Equivalent to x.push_back(&y) */ static inline SeriesWork& operator>(SeriesWork& x, ParallelWork& y); /** * @brief S=P>S * @note Equivalent to y.push_front(&x) */ static inline SeriesWork& operator>(ParallelWork& x, SeriesWork& y); /** * @brief S=S>t * @note Equivalent to x.push_back(&y) */ static inline SeriesWork& operator>(SeriesWork& x, SubTask& y); /** * @brief S=S>t * @note Equivalent to x.push_back(y) */ static inline SeriesWork& operator>(SeriesWork& x, SubTask *y); /** * @brief S=t>S * @note Equivalent to y.push_front(&x) */ static inline SeriesWork& operator>(SubTask& x, SeriesWork& y); /** * @brief S=t>S * @note Equivalent to y.push_front(x) */ static inline SeriesWork& operator>(SubTask *x, SeriesWork& y); /** * @brief S=P>P * @note Equivalent to Workflow::create_series_work(&x, nullptr)->push_back(&y) */ static inline SeriesWork& operator>(ParallelWork& x, ParallelWork& y); /** * @brief S=P>t * @note Equivalent to Workflow::create_series_work(&x, nullptr)->push_back(&y) */ static inline SeriesWork& operator>(ParallelWork& x, SubTask& y); /** * @brief S=P>t * @note Equivalent to 
Workflow::create_series_work(&x, nullptr)->push_back(y) */ static inline SeriesWork& operator>(ParallelWork& x, SubTask *y); /** * @brief S=t>P * @note Equivalent to Workflow::create_series_work(&x, nullptr)->push_back(&y) */ static inline SeriesWork& operator>(SubTask& x, ParallelWork& y); /** * @brief S=t>P * @note Equivalent to Workflow::create_series_work(x, nullptr)->push_back(&y) */ static inline SeriesWork& operator>(SubTask *x, ParallelWork& y); /** * @brief S=t>t * @note Equivalent to Workflow::create_series_work(&x, nullptr)->push_back(&y) */ static inline SeriesWork& operator>(SubTask& x, SubTask& y); /** * @brief S=t>t * @note Equivalent to Workflow::create_series_work(&x, nullptr)->push_back(y) */ static inline SeriesWork& operator>(SubTask& x, SubTask *y); /** * @brief S=t>t * @note Equivalent to Workflow::create_series_work(x, nullptr)->push_back(&y) */ static inline SeriesWork& operator>(SubTask *x, SubTask& y); //static inline SeriesWork& operator>(SubTask *x, SubTask *y);//compile error! 
/** * @brief P=P*S * @note Equivalent to x.add_series(&y) */ static inline ParallelWork& operator*(ParallelWork& x, SeriesWork& y); /** * @brief P=S*P * @note Equivalent to y.add_series(&x) */ static inline ParallelWork& operator*(SeriesWork& x, ParallelWork& y); /** * @brief P=P*t * @note Equivalent to x.add_series(Workflow::create_series_work(&y, nullptr)) */ static inline ParallelWork& operator*(ParallelWork& x, SubTask& y); /** * @brief P=P*t * @note Equivalent to x.add_series(Workflow::create_series_work(y, nullptr)) */ static inline ParallelWork& operator*(ParallelWork& x, SubTask *y); /** * @brief P=t*P * @note Equivalent to y.add_series(Workflow::create_series_work(&x, nullptr)) */ static inline ParallelWork& operator*(SubTask& x, ParallelWork& y); /** * @brief P=t*P * @note Equivalent to y.add_series(Workflow::create_series_work(x, nullptr)) */ static inline ParallelWork& operator*(SubTask *x, ParallelWork& y); /** * @brief P=S*S * @note Equivalent to Workflow::create_parallel_work({&x, &y}, 2, nullptr) */ static inline ParallelWork& operator*(SeriesWork& x, SeriesWork& y); /** * @brief P=S*t * @note Equivalent to Workflow::create_parallel_work({&y}, 1, nullptr)->add_series(&x) */ static inline ParallelWork& operator*(SeriesWork& x, SubTask& y); /** * @brief P=S*t * @note Equivalent to Workflow::create_parallel_work({y}, 1, nullptr)->add_series(&x) */ static inline ParallelWork& operator*(SeriesWork& x, SubTask *y); /** * @brief P=t*S * @note Equivalent to Workflow::create_parallel_work({&x}, 1, nullptr)->add_series(&y) */ static inline ParallelWork& operator*(SubTask& x, SeriesWork& y); /** * @brief P=t*S * @note Equivalent to Workflow::create_parallel_work({x}, 1, nullptr)->add_series(&y) */ static inline ParallelWork& operator*(SubTask *x, SeriesWork& y); /** * @brief P=t*t * @note Equivalent to Workflow::create_parallel_work({Workflow::create_series_work(&x, nullptr), Workflow::create_series_work(&y, nullptr)}, 2, nullptr) */ static inline ParallelWork& 
operator*(SubTask& x, SubTask& y); /** * @brief P=t*t * @note Equivalent to Workflow::create_parallel_work({Workflow::create_series_work(&x, nullptr), Workflow::create_series_work(y, nullptr)}, 2, nullptr) */ static inline ParallelWork& operator*(SubTask& x, SubTask *y); /** * @brief P=t*t * @note Equivalent to Workflow::create_parallel_work({Workflow::create_series_work(x, nullptr), Workflow::create_series_work(&y, nullptr)}, 2, nullptr) */ static inline ParallelWork& operator*(SubTask *x, SubTask& y); //static inline ParallelWork& operator*(SubTask *x, SubTask *y);//compile error! //S=S>t static inline SeriesWork& operator>(SeriesWork& x, SubTask& y) { x.push_back(&y); return x; } static inline SeriesWork& operator>(SeriesWork& x, SubTask *y) { return x > *y; } //S=t>S static inline SeriesWork& operator>(SubTask& x, SeriesWork& y) { y.push_front(&x); return y; } static inline SeriesWork& operator>(SubTask *x, SeriesWork& y) { return *x > y; } //S=t>t static inline SeriesWork& operator>(SubTask& x, SubTask& y) { SeriesWork *series = Workflow::create_series_work(&x, nullptr); series->push_back(&y); return *series; } static inline SeriesWork& operator>(SubTask& x, SubTask *y) { return x > *y; } static inline SeriesWork& operator>(SubTask *x, SubTask& y) { return *x > y; } //S=S>P static inline SeriesWork& operator>(SeriesWork& x, ParallelWork& y) { return x > (SubTask&)y; } //S=P>S static inline SeriesWork& operator>(ParallelWork& x, SeriesWork& y) { return y > (SubTask&)x; } //S=P>P static inline SeriesWork& operator>(ParallelWork& x, ParallelWork& y) { return (SubTask&)x > (SubTask&)y; } //S=P>t static inline SeriesWork& operator>(ParallelWork& x, SubTask& y) { return (SubTask&)x > y; } static inline SeriesWork& operator>(ParallelWork& x, SubTask *y) { return x > *y; } //S=t>P static inline SeriesWork& operator>(SubTask& x, ParallelWork& y) { return x > (SubTask&)y; } static inline SeriesWork& operator>(SubTask *x, ParallelWork& y) { return *x > y; } //P=P*S 
static inline ParallelWork& operator*(ParallelWork& x, SeriesWork& y) { x.add_series(&y); return x; } //P=S*P static inline ParallelWork& operator*(SeriesWork& x, ParallelWork& y) { return y * x; } //P=P*t static inline ParallelWork& operator*(ParallelWork& x, SubTask& y) { x.add_series(Workflow::create_series_work(&y, nullptr)); return x; } static inline ParallelWork& operator*(ParallelWork& x, SubTask *y) { return x * (*y); } //P=t*P static inline ParallelWork& operator*(SubTask& x, ParallelWork& y) { return y * x; } static inline ParallelWork& operator*(SubTask *x, ParallelWork& y) { return (*x) * y; } //P=S*S static inline ParallelWork& operator*(SeriesWork& x, SeriesWork& y) { SeriesWork *arr[2] = {&x, &y}; return *Workflow::create_parallel_work(arr, 2, nullptr); } //P=S*t static inline ParallelWork& operator*(SeriesWork& x, SubTask& y) { return x * (*Workflow::create_series_work(&y, nullptr)); } static inline ParallelWork& operator*(SeriesWork& x, SubTask *y) { return x * (*y); } //P=t*S static inline ParallelWork& operator*(SubTask& x, SeriesWork& y) { return y * x; } static inline ParallelWork& operator*(SubTask *x, SeriesWork& y) { return (*x) * y; } //P=t*t static inline ParallelWork& operator*(SubTask& x, SubTask& y) { SeriesWork *arr[2] = {Workflow::create_series_work(&x, nullptr), Workflow::create_series_work(&y, nullptr)}; return *Workflow::create_parallel_work(arr, 2, nullptr); } static inline ParallelWork& operator*(SubTask& x, SubTask *y) { return x * (*y); } static inline ParallelWork& operator*(SubTask *x, SubTask& y) { return (*x) * y; } #endif workflow-0.11.8/src/factory/WFResourcePool.cc000066400000000000000000000050061476003635400211070ustar00rootroot00000000000000/* Copyright (c) 2021 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Authors: Li Yingxin (liyingxin@sogou-inc.com) Xie Han (xiehan@sogou-inc.com) */ #include #include "list.h" #include "WFTask.h" #include "WFResourcePool.h" class __RPConditional : public WFConditional { public: struct list_head list; struct WFResourcePool::Data *data; public: virtual void dispatch(); virtual void signal(void *res) { } public: __RPConditional(SubTask *task, void **resbuf, struct WFResourcePool::Data *data) : WFConditional(task, resbuf) { this->data = data; } __RPConditional(SubTask *task, struct WFResourcePool::Data *data) : WFConditional(task) { this->data = data; } }; void __RPConditional::dispatch() { struct WFResourcePool::Data *data = this->data; data->mutex.lock(); if (--data->value >= 0) this->WFConditional::signal(data->pop()); else list_add_tail(&this->list, &data->wait_list); data->mutex.unlock(); this->WFConditional::dispatch(); } WFConditional *WFResourcePool::get(SubTask *task, void **resbuf) { return new __RPConditional(task, resbuf, &this->data); } WFConditional *WFResourcePool::get(SubTask *task) { return new __RPConditional(task, &this->data); } void WFResourcePool::create(size_t n) { this->data.res = new void *[n]; this->data.value = n; this->data.index = 0; INIT_LIST_HEAD(&this->data.wait_list); this->data.pool = this; } WFResourcePool::WFResourcePool(void *const *res, size_t n) { this->create(n); memcpy(this->data.res, res, n * sizeof (void *)); } WFResourcePool::WFResourcePool(size_t n) { this->create(n); memset(this->data.res, 0, n * sizeof (void *)); } void WFResourcePool::post(void *res) { struct WFResourcePool::Data *data = &this->data; 
WFConditional *cond; data->mutex.lock(); if (++data->value <= 0) { cond = list_entry(data->wait_list.next, __RPConditional, list); list_del(data->wait_list.next); } else { cond = NULL; this->push(res); } data->mutex.unlock(); if (cond) cond->WFConditional::signal(res); } workflow-0.11.8/src/factory/WFResourcePool.h000066400000000000000000000030131476003635400207450ustar00rootroot00000000000000/* Copyright (c) 2021 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Authors: Li Yingxin (liyingxin@sogou-inc.com) Xie Han (xiehan@sogou-inc.com) */ #ifndef _WFRESOURCEPOOL_H_ #define _WFRESOURCEPOOL_H_ #include #include "list.h" #include "WFTask.h" class WFResourcePool { public: WFConditional *get(SubTask *task, void **resbuf); WFConditional *get(SubTask *task); void post(void *res); public: struct Data { void *pop() { return this->pool->pop(); } void push(void *res) { this->pool->push(res); } void **res; long value; size_t index; struct list_head wait_list; std::mutex mutex; WFResourcePool *pool; }; protected: virtual void *pop() { return this->data.res[this->data.index++]; } virtual void push(void *res) { this->data.res[--this->data.index] = res; } protected: struct Data data; private: void create(size_t n); public: WFResourcePool(void *const *res, size_t n); WFResourcePool(size_t n); virtual ~WFResourcePool() { delete []this->data.res; } }; #endif workflow-0.11.8/src/factory/WFTask.h000066400000000000000000000374341476003635400172440ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Xie Han (xiehan@sogou-inc.com) */ #ifndef _WFTASK_H_ #define _WFTASK_H_ #include #include #include #include #include #include #include "Executor.h" #include "ExecRequest.h" #include "Communicator.h" #include "CommScheduler.h" #include "CommRequest.h" #include "SleepRequest.h" #include "IORequest.h" #include "Workflow.h" #include "WFConnection.h" enum { WFT_STATE_UNDEFINED = -1, WFT_STATE_SUCCESS = CS_STATE_SUCCESS, WFT_STATE_TOREPLY = CS_STATE_TOREPLY, /* for server task only */ WFT_STATE_NOREPLY = CS_STATE_TOREPLY + 1, /* for server task only */ WFT_STATE_SYS_ERROR = CS_STATE_ERROR, WFT_STATE_SSL_ERROR = 65, WFT_STATE_DNS_ERROR = 66, /* for client task only */ WFT_STATE_TASK_ERROR = 67, WFT_STATE_ABORTED = CS_STATE_STOPPED }; template class WFThreadTask : public ExecRequest { public: void start() { assert(!series_of(this)); Workflow::start_series_work(this, nullptr); } void dismiss() { assert(!series_of(this)); delete this; } public: INPUT *get_input() { return &this->input; } OUTPUT *get_output() { return &this->output; } public: void *user_data; public: int get_state() const { return this->state; } int get_error() const { return this->error; } public: void set_callback(std::function *)> cb) { this->callback = std::move(cb); } protected: virtual SubTask *done() { SeriesWork *series = series_of(this); if (this->callback) this->callback(this); delete this; return series->pop(); } protected: INPUT input; OUTPUT output; std::function *)> callback; 
public: WFThreadTask(ExecQueue *queue, Executor *executor, std::function *)>&& cb) : ExecRequest(queue, executor), callback(std::move(cb)) { this->user_data = NULL; this->state = WFT_STATE_UNDEFINED; this->error = 0; } protected: virtual ~WFThreadTask() { } }; template class WFNetworkTask : public CommRequest { public: /* start(), dismiss() are for client tasks only. */ void start() { assert(!series_of(this)); Workflow::start_series_work(this, nullptr); } void dismiss() { assert(!series_of(this)); delete this; } public: REQ *get_req() { return &this->req; } RESP *get_resp() { return &this->resp; } public: void *user_data; public: int get_state() const { return this->state; } int get_error() const { return this->error; } /* Call when error is ETIMEDOUT, return values: * TOR_NOT_TIMEOUT, TOR_WAIT_TIMEOUT, TOR_CONNECT_TIMEOUT, * TOR_TRANSMIT_TIMEOUT (send or receive). * SSL connect timeout also returns TOR_CONNECT_TIMEOUT. */ int get_timeout_reason() const { return this->timeout_reason; } /* Call only in callback or server's process. */ long long get_task_seq() const { if (!this->target) { errno = ENOTCONN; return -1; } return this->get_seq(); } int get_peer_addr(struct sockaddr *addr, socklen_t *addrlen) const; virtual WFConnection *get_connection() const = 0; public: /* All in milliseconds. timeout == -1 for unlimited. */ void set_send_timeout(int timeout) { this->send_timeo = timeout; } void set_receive_timeout(int timeout) { this->receive_timeo = timeout; } void set_keep_alive(int timeout) { this->keep_alive_timeo = timeout; } void set_watch_timeout(int timeout) { this->watch_timeo = timeout; } public: /* Do not reply this request. */ void noreply() { if (this->state == WFT_STATE_TOREPLY) this->state = WFT_STATE_NOREPLY; } /* Push reply data synchronously. 
*/ virtual int push(const void *buf, size_t size) { if (this->state != WFT_STATE_TOREPLY && this->state != WFT_STATE_NOREPLY) { errno = ENOENT; return -1; } return this->scheduler->push(buf, size, this); } /* To check if the connection was closed before replying. Always returns 'true' in callback. */ bool closed() const { switch (this->state) { case WFT_STATE_UNDEFINED: return false; case WFT_STATE_TOREPLY: case WFT_STATE_NOREPLY: return !this->target->has_idle_conn(); default: return true; } } public: void set_prepare(std::function *)> prep) { this->prepare = std::move(prep); } public: void set_callback(std::function *)> cb) { this->callback = std::move(cb); } protected: virtual int send_timeout() { return this->send_timeo; } virtual int receive_timeout() { return this->receive_timeo; } virtual int keep_alive_timeout() { return this->keep_alive_timeo; } virtual int first_timeout() { return this->watch_timeo; } protected: int send_timeo; int receive_timeo; int keep_alive_timeo; int watch_timeo; REQ req; RESP resp; std::function *)> prepare; std::function *)> callback; protected: WFNetworkTask(CommSchedObject *object, CommScheduler *scheduler, std::function *)>&& cb) : CommRequest(object, scheduler), callback(std::move(cb)) { this->send_timeo = -1; this->receive_timeo = -1; this->keep_alive_timeo = 0; this->watch_timeo = 0; this->target = NULL; this->timeout_reason = TOR_NOT_TIMEOUT; this->user_data = NULL; this->state = WFT_STATE_UNDEFINED; this->error = 0; } virtual ~WFNetworkTask() { } }; class WFTimerTask : public SleepRequest { public: void start() { assert(!series_of(this)); Workflow::start_series_work(this, nullptr); } void dismiss() { assert(!series_of(this)); delete this; } public: void *user_data; public: int get_state() const { return this->state; } int get_error() const { return this->error; } public: void set_callback(std::function cb) { this->callback = std::move(cb); } protected: virtual SubTask *done() { SeriesWork *series = series_of(this); if 
(this->callback) this->callback(this); delete this; return series->pop(); } protected: std::function callback; public: WFTimerTask(CommScheduler *scheduler, std::function cb) : SleepRequest(scheduler), callback(std::move(cb)) { this->user_data = NULL; this->state = WFT_STATE_UNDEFINED; this->error = 0; } protected: virtual ~WFTimerTask() { } }; template class WFFileTask : public IORequest { public: void start() { assert(!series_of(this)); Workflow::start_series_work(this, nullptr); } void dismiss() { assert(!series_of(this)); delete this; } public: ARGS *get_args() { return &this->args; } long get_retval() const { if (this->state == WFT_STATE_SUCCESS) return this->get_res(); else return -1; } public: void *user_data; public: int get_state() const { return this->state; } int get_error() const { return this->error; } public: void set_callback(std::function *)> cb) { this->callback = std::move(cb); } protected: virtual SubTask *done() { SeriesWork *series = series_of(this); if (this->callback) this->callback(this); delete this; return series->pop(); } protected: ARGS args; std::function *)> callback; public: WFFileTask(IOService *service, std::function *)>&& cb) : IORequest(service), callback(std::move(cb)) { this->user_data = NULL; this->state = WFT_STATE_UNDEFINED; this->error = 0; } protected: virtual ~WFFileTask() { } }; class WFGenericTask : public SubTask { public: void start() { assert(!series_of(this)); Workflow::start_series_work(this, nullptr); } void dismiss() { assert(!series_of(this)); delete this; } public: void *user_data; public: int get_state() const { return this->state; } int get_error() const { return this->error; } protected: virtual void dispatch() { this->subtask_done(); } virtual SubTask *done() { SeriesWork *series = series_of(this); delete this; return series->pop(); } protected: int state; int error; public: WFGenericTask() { this->user_data = NULL; this->state = WFT_STATE_UNDEFINED; this->error = 0; } protected: virtual ~WFGenericTask() { } 
}; class WFCounterTask : public WFGenericTask { public: virtual void count() { if (--this->value == 0) { this->state = WFT_STATE_SUCCESS; this->subtask_done(); } } public: void set_callback(std::function cb) { this->callback = std::move(cb); } protected: virtual void dispatch() { this->WFCounterTask::count(); } virtual SubTask *done() { SeriesWork *series = series_of(this); if (this->callback) this->callback(this); delete this; return series->pop(); } protected: std::atomic value; std::function callback; public: WFCounterTask(unsigned int target_value, std::function&& cb) : value(target_value + 1), callback(std::move(cb)) { } protected: virtual ~WFCounterTask() { } }; class WFMailboxTask : public WFGenericTask { public: virtual void send(void *msg) { *this->mailbox = msg; if (this->flag.exchange(true)) { this->state = WFT_STATE_SUCCESS; this->subtask_done(); } } void **get_mailbox() const { return this->mailbox; } public: void set_callback(std::function cb) { this->callback = std::move(cb); } protected: virtual void dispatch() { if (this->flag.exchange(true)) { this->state = WFT_STATE_SUCCESS; this->subtask_done(); } } virtual SubTask *done() { SeriesWork *series = series_of(this); if (this->callback) this->callback(this); delete this; return series->pop(); } protected: void **mailbox; std::atomic flag; std::function callback; public: WFMailboxTask(void **mailbox, std::function&& cb) : flag(false), callback(std::move(cb)) { this->mailbox = mailbox; } WFMailboxTask(std::function&& cb) : flag(false), callback(std::move(cb)) { this->mailbox = &this->user_data; } protected: virtual ~WFMailboxTask() { } }; class WFSelectorTask : public WFGenericTask { public: virtual int submit(void *msg) { void *tmp = NULL; int ret = 0; if (this->message.compare_exchange_strong(tmp, msg) && msg) { ret = 1; if (this->flag.exchange(true)) { this->state = WFT_STATE_SUCCESS; this->subtask_done(); } } if (--this->nleft == 0) { if (!this->message) { this->state = WFT_STATE_SYS_ERROR; 
this->error = ENOMSG; this->subtask_done(); } delete this; } return ret; } void *get_message() const { return this->message; } public: void set_callback(std::function cb) { this->callback = std::move(cb); } protected: virtual void dispatch() { if (this->flag.exchange(true)) { this->state = WFT_STATE_SUCCESS; this->subtask_done(); } if (--this->nleft == 0) { if (!this->message) { this->state = WFT_STATE_SYS_ERROR; this->error = ENOMSG; this->subtask_done(); } delete this; } } virtual SubTask *done() { SeriesWork *series = series_of(this); if (this->callback) this->callback(this); return series->pop(); } protected: std::atomic message; std::atomic flag; std::atomic nleft; std::function callback; public: WFSelectorTask(size_t candidates, std::function&& cb) : message(NULL), flag(false), nleft(candidates + 1), callback(std::move(cb)) { } protected: virtual ~WFSelectorTask() { } }; class WFConditional : public WFGenericTask { public: virtual void signal(void *msg) { *this->msgbuf = msg; if (this->flag.exchange(true)) this->subtask_done(); } protected: virtual void dispatch() { series_of(this)->push_front(this->task); this->task = NULL; if (this->flag.exchange(true)) this->subtask_done(); } protected: std::atomic flag; SubTask *task; void **msgbuf; public: WFConditional(SubTask *task, void **msgbuf) : flag(false) { this->task = task; this->msgbuf = msgbuf; } WFConditional(SubTask *task) : flag(false) { this->task = task; this->msgbuf = &this->user_data; } protected: virtual ~WFConditional() { delete this->task; } }; class WFGoTask : public ExecRequest { public: void start() { assert(!series_of(this)); Workflow::start_series_work(this, nullptr); } void dismiss() { assert(!series_of(this)); delete this; } public: void *user_data; public: int get_state() const { return this->state; } int get_error() const { return this->error; } public: void set_callback(std::function cb) { this->callback = std::move(cb); } protected: virtual SubTask *done() { SeriesWork *series = 
series_of(this); if (this->callback) this->callback(this); delete this; return series->pop(); } protected: std::function callback; public: WFGoTask(ExecQueue *queue, Executor *executor) : ExecRequest(queue, executor) { this->user_data = NULL; this->state = WFT_STATE_UNDEFINED; this->error = 0; } protected: virtual ~WFGoTask() { } }; class WFRepeaterTask : public WFGenericTask { public: void set_create(std::function create) { this->create = std::move(create); } public: void set_callback(std::function cb) { this->callback = std::move(cb); } protected: virtual void dispatch() { SubTask *task = this->create(this); if (task) { series_of(this)->push_front(this); series_of(this)->push_front(task); } else this->state = WFT_STATE_SUCCESS; this->subtask_done(); } virtual SubTask *done() { SeriesWork *series = series_of(this); if (this->state != WFT_STATE_UNDEFINED) { if (this->callback) this->callback(this); delete this; } return series->pop(); } protected: std::function create; std::function callback; public: WFRepeaterTask(std::function&& create, std::function&& cb) : create(std::move(create)), callback(std::move(cb)) { } protected: virtual ~WFRepeaterTask() { } }; class WFModuleTask : public ParallelTask, protected SeriesWork { public: void start() { assert(!series_of(this)); Workflow::start_series_work(this, nullptr); } void dismiss() { assert(!series_of(this)); delete this; } public: SeriesWork *sub_series() { return this; } const SeriesWork *sub_series() const { return this; } public: void *user_data; public: void set_callback(std::function cb) { this->callback = std::move(cb); } protected: virtual SubTask *done() { SeriesWork *series = series_of(this); if (this->callback) this->callback(this); delete this; return series->pop(); } protected: SubTask *first; std::function callback; public: WFModuleTask(SubTask *first, std::function&& cb) : ParallelTask(&this->first, 1), SeriesWork(first, nullptr), callback(std::move(cb)) { this->first = first; 
this->set_in_parallel(this); this->user_data = NULL; } protected: virtual ~WFModuleTask() { if (!this->is_finished()) this->dismiss_recursive(); } }; #include "WFTask.inl" #endif workflow-0.11.8/src/factory/WFTask.inl000066400000000000000000000124411476003635400175660ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Xie Han (xiehan@sogou-inc.com) */ template int WFNetworkTask::get_peer_addr(struct sockaddr *addr, socklen_t *addrlen) const { const struct sockaddr *p; socklen_t len; if (this->target) { this->target->get_addr(&p, &len); if (*addrlen >= len) { memcpy(addr, p, len); *addrlen = len; return 0; } errno = ENOBUFS; } else errno = ENOTCONN; return -1; } template class WFClientTask : public WFNetworkTask { protected: virtual CommMessageOut *message_out() { /* By using prepare function, users can modify the request after * the connection is established. 
*/ if (this->prepare) this->prepare(this); return &this->req; } virtual CommMessageIn *message_in() { return &this->resp; } protected: virtual WFConnection *get_connection() const { CommConnection *conn; if (this->target) { conn = this->CommSession::get_connection(); if (conn) return (WFConnection *)conn; } errno = ENOTCONN; return NULL; } protected: virtual SubTask *done() { SeriesWork *series = series_of(this); if (this->state == WFT_STATE_SYS_ERROR && this->error < 0) { this->state = WFT_STATE_SSL_ERROR; this->error = -this->error; } if (this->callback) this->callback(this); delete this; return series->pop(); } public: WFClientTask(CommSchedObject *object, CommScheduler *scheduler, std::function *)>&& cb) : WFNetworkTask(object, scheduler, std::move(cb)) { } protected: virtual ~WFClientTask() { } }; template class WFServerTask : public WFNetworkTask { protected: virtual CommMessageOut *message_out() { /* By using prepare function, users can modify the response before * replying to the client. */ if (this->prepare) this->prepare(this); return &this->resp; } virtual CommMessageIn *message_in() { return &this->req; } virtual void handle(int state, int error); protected: /* CommSession::get_connection() is supposed to be called only in the * implementations of it's virtual functions. As a server task, to call * this function after process() and before callback() is very dangerous * and should be blocked. */ virtual WFConnection *get_connection() const { if (this->processor.task) return (WFConnection *)this->CommSession::get_connection(); errno = EPERM; return NULL; } protected: virtual void dispatch() { if (this->state == WFT_STATE_TOREPLY) { /* Enable get_connection() again if the reply() call is success. 
*/ this->processor.task = this; if (this->scheduler->reply(this) >= 0) return; this->state = WFT_STATE_SYS_ERROR; this->error = errno; this->processor.task = NULL; } else this->scheduler->shutdown(this); this->subtask_done(); } virtual SubTask *done() { SeriesWork *series = series_of(this); if (this->state == WFT_STATE_SYS_ERROR && this->error < 0) { this->state = WFT_STATE_SSL_ERROR; this->error = -this->error; } if (this->callback) this->callback(this); /* Defer deleting the task. */ return series->pop(); } protected: class Processor : public SubTask { public: Processor(WFServerTask *task, std::function *)>& proc) : process(proc) { this->task = task; } virtual void dispatch() { this->process(this->task); this->task = NULL; /* As a flag. get_conneciton() disabled. */ this->subtask_done(); } virtual SubTask *done() { return series_of(this)->pop(); } std::function *)>& process; WFServerTask *task; } processor; class Series : public SeriesWork { public: Series(WFServerTask *task) : SeriesWork(&task->processor, nullptr) { this->set_last_task(task); this->task = task; } virtual ~Series() { delete this->task; } WFServerTask *task; }; public: WFServerTask(CommService *service, CommScheduler *scheduler, std::function *)>& proc) : WFNetworkTask(NULL, scheduler, nullptr), processor(this, proc) { } protected: virtual ~WFServerTask() { if (this->target) ((Series *)series_of(this))->task = NULL; } }; template void WFServerTask::handle(int state, int error) { if (state == WFT_STATE_TOREPLY) { this->state = WFT_STATE_TOREPLY; this->target = this->get_target(); new Series(this); this->processor.dispatch(); } else if (this->state == WFT_STATE_TOREPLY) { this->state = state; this->error = error; if (error == ETIMEDOUT) this->timeout_reason = TOR_TRANSMIT_TIMEOUT; this->subtask_done(); } else delete this; } workflow-0.11.8/src/factory/WFTaskError.h000066400000000000000000000056571476003635400202600ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) */ #ifndef _WFTASKERROR_H_ #define _WFTASKERROR_H_ /** * @file WFTaskError.h * @brief Workflow Task Error Code List */ /** * @brief Defination of task state code * @note Only for WFNetworkTask and only when get_state()==WFT_STATE_TASK_ERROR */ enum { //COMMON WFT_ERR_URI_PARSE_FAILED = 1001, ///< URI, parse failed WFT_ERR_URI_SCHEME_INVALID = 1002, ///< URI, invalid scheme WFT_ERR_URI_PORT_INVALID = 1003, ///< URI, invalid port WFT_ERR_UPSTREAM_UNAVAILABLE = 1004, ///< Upstream, all target server down //HTTP WFT_ERR_HTTP_BAD_REDIRECT_HEADER = 2001, ///< Http, 301/302/303/307/308 Location header value is NULL WFT_ERR_HTTP_PROXY_CONNECT_FAILED = 2002, ///< Http, proxy CONNECT return non 200 //REDIS WFT_ERR_REDIS_ACCESS_DENIED = 3001, ///< Redis, invalid password WFT_ERR_REDIS_COMMAND_DISALLOWED = 3002, ///< Redis, command disabled, cannot be "AUTH"/"SELECT" //MYSQL WFT_ERR_MYSQL_HOST_NOT_ALLOWED = 4001, ///< MySQL WFT_ERR_MYSQL_ACCESS_DENIED = 4002, ///< MySQL, authentication failed WFT_ERR_MYSQL_INVALID_CHARACTER_SET = 4003, ///< MySQL, invalid charset, not found in MySQL-Documentation WFT_ERR_MYSQL_COMMAND_DISALLOWED = 4004, ///< MySQL, sql command disabled, cannot be "USE"/"SET NAMES"/"SET CHARSET"/"SET CHARACTER SET" WFT_ERR_MYSQL_QUERY_NOT_SET = 4005, ///< MySQL, query not set sql, maybe forget please check WFT_ERR_MYSQL_SSL_NOT_SUPPORTED = 4006, ///< MySQL, SSL not supported by the server //KAFKA 
WFT_ERR_KAFKA_PARSE_RESPONSE_FAILED = 5001, ///< Kafka parse response failed WFT_ERR_KAFKA_PRODUCE_FAILED = 5002, WFT_ERR_KAFKA_FETCH_FAILED = 5003, WFT_ERR_KAFKA_CGROUP_FAILED = 5004, WFT_ERR_KAFKA_COMMIT_FAILED = 5005, WFT_ERR_KAFKA_META_FAILED = 5006, WFT_ERR_KAFKA_LEAVEGROUP_FAILED = 5007, WFT_ERR_KAFKA_API_UNKNOWN = 5008, ///< api type not supported WFT_ERR_KAFKA_VERSION_DISALLOWED = 5009, ///< broker version not supported WFT_ERR_KAFKA_SASL_DISALLOWED = 5010, ///< sasl not supported WFT_ERR_KAFKA_ARRANGE_FAILED = 5011, ///< arrange toppar failed WFT_ERR_KAFKA_LIST_OFFSETS_FAILED = 5012, WFT_ERR_KAFKA_CGROUP_ASSIGN_FAILED = 5013, //CONSUL WFT_ERR_CONSUL_API_UNKNOWN = 6001, ///< api type not supported WFT_ERR_CONSUL_CHECK_RESPONSE_FAILED = 6002, ///< Consul http code failed }; #endif workflow-0.11.8/src/factory/WFTaskFactory.cc000066400000000000000000000567201476003635400207310ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Authors: Xie Han (xiehan@sogou-inc.com) */ #include #include #include #include #include #include #include #include "list.h" #include "rbtree.h" #include "WFGlobal.h" #include "WFTaskFactory.h" class __WFTimerTask : public WFTimerTask { protected: virtual int duration(struct timespec *value) { value->tv_sec = this->seconds; value->tv_nsec = this->nanoseconds; return 0; } protected: time_t seconds; long nanoseconds; public: __WFTimerTask(time_t seconds, long nanoseconds, CommScheduler *scheduler, timer_callback_t&& cb) : WFTimerTask(scheduler, std::move(cb)) { this->seconds = seconds; this->nanoseconds = nanoseconds; } }; WFTimerTask *WFTaskFactory::create_timer_task(time_t seconds, long nanoseconds, timer_callback_t callback) { return new __WFTimerTask(seconds, nanoseconds, WFGlobal::get_scheduler(), std::move(callback)); } /* Deprecated. */ WFTimerTask *WFTaskFactory::create_timer_task(unsigned int microseconds, timer_callback_t callback) { return WFTaskFactory::create_timer_task(microseconds / 1000000, microseconds % 1000000 * 1000, std::move(callback)); } /****************** Named Tasks ******************/ template struct __NamedObjectList { __NamedObjectList(const std::string& str): name(str) { INIT_LIST_HEAD(&this->head); } void push_back(T *node) { list_add_tail(&node->list, &this->head); } bool empty() const { return list_empty(&this->head); } bool del(T *node, rb_root *root) { list_del(&node->list); if (this->empty()) { rb_erase(&this->rb, root); return true; } else return false; } struct rb_node rb; struct list_head head; std::string name; }; template static T *__get_object_list(const std::string& name, struct rb_root *root, bool insert) { struct rb_node **p = &root->rb_node; struct rb_node *parent = NULL; T *objs; int n; while (*p) { parent = *p; objs = rb_entry(*p, T, rb); n = name.compare(objs->name); if (n < 0) p = &(*p)->rb_left; else if (n > 0) p = &(*p)->rb_right; else return objs; } if (insert) { objs = new T(name); rb_link_node(&objs->rb, parent, 
p); rb_insert_color(&objs->rb, root); return objs; } return NULL; } /****************** Named Timer ******************/ class __WFNamedTimerTask; struct __timer_node { struct list_head list; __WFNamedTimerTask *task; }; static class __NamedTimerMap { public: using TimerList = __NamedObjectList; public: WFTimerTask *create(const std::string& name, time_t seconds, long nanoseconds, CommScheduler *scheduler, timer_callback_t&& cb); public: int cancel(const std::string& name, size_t max); private: struct rb_root root_; std::mutex mutex_; public: __NamedTimerMap() { root_.rb_node = NULL; } friend class __WFNamedTimerTask; } __timer_map; class __WFNamedTimerTask : public __WFTimerTask { public: __WFNamedTimerTask(time_t seconds, long nanoseconds, CommScheduler *scheduler, timer_callback_t&& cb) : __WFTimerTask(seconds, nanoseconds, scheduler, std::move(cb)), flag_(false) { node_.task = this; } void push_to(__NamedTimerMap::TimerList *timers) { timers->push_back(&node_); timers_ = timers; } virtual ~__WFNamedTimerTask() { if (node_.task) { bool erased = false; __timer_map.mutex_.lock(); if (node_.task) erased = timers_->del(&node_, &__timer_map.root_); __timer_map.mutex_.unlock(); if (erased) delete timers_; } } protected: virtual void dispatch(); virtual void handle(int state, int error); private: struct __timer_node node_; __NamedTimerMap::TimerList *timers_; std::atomic flag_; std::mutex mutex_; friend class __NamedTimerMap; }; void __WFNamedTimerTask::dispatch() { int ret; mutex_.lock(); ret = this->scheduler->sleep(this); if (ret >= 0 && flag_.exchange(true)) this->cancel(); mutex_.unlock(); if (ret < 0) this->handle(WFT_STATE_SYS_ERROR, errno); } void __WFNamedTimerTask::handle(int state, int error) { bool canceled = true; if (node_.task) { bool erased = false; __timer_map.mutex_.lock(); if (node_.task) { canceled = false; erased = timers_->del(&node_, &__timer_map.root_); node_.task = NULL; } __timer_map.mutex_.unlock(); if (erased) delete timers_; } if (canceled) 
{
		// Reached when node_.task was already NULL on entry.
		state = WFT_STATE_SYS_ERROR;
		error = ECANCELED;
	}

	// dispatch() holds mutex_ across sleep() and its own cancel(); taking and
	// releasing the lock here waits until dispatch() has left that section.
	mutex_.lock();
	mutex_.unlock();
	this->__WFTimerTask::handle(state, error);
}

/* Allocates a named timer task and appends it to the list obtained for
 * 'name' while holding the map mutex. */
WFTimerTask *__NamedTimerMap::create(const std::string& name,
									 time_t seconds, long nanoseconds,
									 CommScheduler *scheduler,
									 timer_callback_t&& cb)
{
	auto *task = new __WFNamedTimerTask(seconds, nanoseconds, scheduler,
										std::move(cb));
	mutex_.lock();
	task->push_to(__get_object_list(name, &root_, true));
	mutex_.unlock();
	return task;
}

/* Takes up to 'max' nodes off the list registered under 'name'.
 * Returns the number of nodes removed (0 if the name is unknown). */
int __NamedTimerMap::cancel(const std::string& name, size_t max)
{
	struct __timer_node *node;
	TimerList *timers;
	int ret = 0;

	mutex_.lock();
	timers = __get_object_list(name, &root_, false);
	if (timers)
	{
		while (1)
		{
			if (max == 0)
			{
				// Entries remain on the list: keep it, so that the
				// 'delete timers' below becomes a no-op.
				timers = NULL;
				break;
			}

			node = list_entry(timers->head.next, struct __timer_node, list);
			list_del(&node->list);
			// exchange() returns the previous value: the timer is cancelled
			// here only if dispatch() had already set the flag; otherwise
			// dispatch() sees the flag set and cancels the timer itself.
			if (node->task->flag_.exchange(true))
				node->task->cancel();

			// Marks the node as unlinked for the task's handle()/destructor.
			node->task = NULL;
			max--;
			ret++;
			if (timers->empty())
			{
				rb_erase(&timers->rb, &root_);
				break;
			}
		}
	}

	mutex_.unlock();
	delete timers;
	return ret;
}

/* Public entry point: named timer running on the global scheduler. */
WFTimerTask *WFTaskFactory::create_timer_task(const std::string& name,
											  time_t seconds, long nanoseconds,
											  timer_callback_t callback)
{
	return __timer_map.create(name, seconds, nanoseconds,
							  WFGlobal::get_scheduler(),
							  std::move(callback));
}

/* Public entry point: cancel at most 'max' timers under 'name'. */
int WFTaskFactory::cancel_by_name(const std::string& name, size_t max)
{
	return __timer_map.cancel(name, max);
}

/****************** Named Counter ******************/

class __WFNamedCounterTask;

/* Links one named counter into its per-name list; target_value is the
 * remaining count still needed by that counter. */
struct __counter_node
{
	struct list_head list;
	unsigned int target_value;
	__WFNamedCounterTask *task;
};

/* Registry of named counters (tree of per-name lists guarded by mutex_). */
static class __NamedCounterMap
{
public:
	// NOTE(review): a template argument list looks to be missing here -- confirm.
	using CounterList = __NamedObjectList;

public:
	WFCounterTask *create(const std::string& name, unsigned int target_value,
						  counter_callback_t&& cb);

	/* Applies 'n' counts to the counters under 'name'. */
	int count_n(const std::string& name, unsigned int n);
	/* Applies one count to a specific node. */
	void count(CounterList *counters, struct __counter_node *node);

	/* Unlinks 'node' from 'counters' under the map mutex. */
	void remove(CounterList *counters, struct __counter_node *node)
	{
		bool erased;

		mutex_.lock();
		erased = counters->del(node, &root_);
mutex_.unlock();
		// The list object is freed outside the lock when del() reports true.
		if (erased)
			delete counters;
	}

private:
	bool count_n_locked(CounterList *counters, unsigned int n,
						struct list_head *task_list);
	struct rb_root root_;
	std::mutex mutex_;

public:
	__NamedCounterMap()
	{
		root_.rb_node = NULL;
	}
} __counter_map;

/* Counter task registered under a name. The base counter is created with a
 * value of 1; the user's target is tracked in node_.target_value instead. */
class __WFNamedCounterTask : public WFCounterTask
{
public:
	__WFNamedCounterTask(unsigned int target_value, counter_callback_t&& cb) :
		WFCounterTask(1, std::move(cb))
	{
		node_.target_value = target_value;
		node_.task = this;
	}

	/* Appends this task's node to 'counters' and remembers the list. */
	void push_to(__NamedCounterMap::CounterList *counters)
	{
		counters->push_back(&node_);
		counters_ = counters;
	}

	/* A direct count() is routed through the map. */
	virtual void count()
	{
		__counter_map.count(counters_, &node_);
	}

	virtual ~__WFNamedCounterTask()
	{
		// Only unlinks when the base counter's value is still non-zero.
		if (this->value != 0)
			__counter_map.remove(counters_, &node_);
	}

private:
	struct __counter_node node_;
	__NamedCounterMap::CounterList *counters_;
};

/* Creates a named counter. A target of 0 yields a plain WFCounterTask that
 * is not registered under the name. */
WFCounterTask *__NamedCounterMap::create(const std::string& name,
										 unsigned int target_value,
										 counter_callback_t&& cb)
{
	if (target_value == 0)
		return new WFCounterTask(0, std::move(cb));

	auto *task = new __WFNamedCounterTask(target_value, std::move(cb));
	mutex_.lock();
	task->push_to(__get_object_list(name, &root_, true));
	mutex_.unlock();
	return task;
}

/* Distributes 'n' counts over the list from its head. Nodes whose remaining
 * target is used up are moved to 'task_list'. Returns true if the list became
 * empty and was erased from the tree (caller then deletes it).
 * Must be called with mutex_ held. */
bool __NamedCounterMap::count_n_locked(CounterList *counters, unsigned int n,
									   struct list_head *task_list)
{
	struct __counter_node *node;

	while (n > 0)
	{
		node = list_entry(counters->head.next, struct __counter_node, list);
		if (n >= node->target_value)
		{
			n -= node->target_value;
			node->target_value = 0;
			list_move_tail(&node->list, task_list);
			if (counters->empty())
			{
				rb_erase(&counters->rb, &root_);
				return true;
			}
		}
		else
		{
			// Not enough left to finish this counter: record the partial
			// progress and stop.
			node->target_value -= n;
			break;
		}
	}

	return false;
}

/* Counts 'n' on the counters under 'name'. Finished counters are collected
 * under the lock and fired after it is released. Returns how many were fired. */
int __NamedCounterMap::count_n(const std::string& name, unsigned int n)
{
	LIST_HEAD(task_list);
	struct __counter_node *node;
	CounterList *counters;
	bool erased = false;
	int ret = 0;

	mutex_.lock();
	counters = __get_object_list(name, &root_, false);
	if (counters)
		erased = count_n_locked(counters, n, &task_list);
mutex_.unlock();
	if (erased)
		delete counters;

	// Fire the finished counters outside the lock.
	while (!list_empty(&task_list))
	{
		node = list_entry(task_list.next, struct __counter_node, list);
		list_del(&node->list);
		node->task->WFCounterTask::count();
		ret++;
	}

	return ret;
}

/* Applies a single count to 'node'. When its remaining target reaches 0 the
 * node is unlinked and the base counter is fired after unlocking. */
void __NamedCounterMap::count(CounterList *counters,
							  struct __counter_node *node)
{
	__WFNamedCounterTask *task = NULL;
	bool erased = false;

	mutex_.lock();
	if (--node->target_value == 0)
	{
		task = node->task;
		erased = counters->del(node, &root_);
	}

	mutex_.unlock();
	if (erased)
		delete counters;

	if (task)
		task->WFCounterTask::count();
}

/* Public entry point: create a counter registered under 'name'. */
WFCounterTask *WFTaskFactory::create_counter_task(const std::string& name,
												  unsigned int target_value,
												  counter_callback_t callback)
{
	return __counter_map.create(name, target_value, std::move(callback));
}

/* Public entry point: count 'n' on the counters under 'name'. */
int WFTaskFactory::count_by_name(const std::string& name, unsigned int n)
{
	return __counter_map.count_n(name, n);
}

/****************** Named Mailbox ******************/

class __WFNamedMailboxTask;

/* Links one named mailbox task into its per-name list. */
struct __mailbox_node
{
	struct list_head list;
	__WFNamedMailboxTask *task;
};

/* Registry of named mailboxes (tree of per-name lists guarded by mutex_). */
static class __NamedMailboxMap
{
public:
	// NOTE(review): a template argument list looks to be missing here -- confirm.
	using MailboxList = __NamedObjectList;

public:
	WFMailboxTask *create(const std::string& name, void **mailbox,
						  mailbox_callback_t&& cb);
	WFMailboxTask *create(const std::string& name, mailbox_callback_t&& cb);

	/* Sends to at most 'max' mailboxes under 'name'; 'inc' is the step
	 * applied to the 'msg' pointer after each delivery. */
	int send(const std::string& name, void *const msg[], size_t max, int inc);
	/* Sends to one specific node. */
	void send(MailboxList *mailboxes, struct __mailbox_node *node, void *msg);

	/* Unlinks 'node' from 'mailboxes' under the map mutex. */
	void remove(MailboxList *mailboxes, struct __mailbox_node *node)
	{
		bool erased;

		mutex_.lock();
		erased = mailboxes->del(node, &root_);
		mutex_.unlock();
		if (erased)
			delete mailboxes;
	}

private:
	bool send_max_locked(MailboxList *mailboxes, size_t max,
						 struct list_head *task_list);
	struct rb_root root_;
	std::mutex mutex_;

public:
	__NamedMailboxMap()
	{
		root_.rb_node = NULL;
	}
} __mailbox_map;

/* Mailbox task registered under a name. */
class __WFNamedMailboxTask : public WFMailboxTask
{
public:
	__WFNamedMailboxTask(void **mailbox, mailbox_callback_t&& cb) :
		WFMailboxTask(mailbox, std::move(cb))
	{
node_.task = this;
	}

	__WFNamedMailboxTask(mailbox_callback_t&& cb) :
		WFMailboxTask(std::move(cb))
	{
		node_.task = this;
	}

	/* Appends this task's node to 'mailboxes' and remembers the list. */
	void push_to(__NamedMailboxMap::MailboxList *mailboxes)
	{
		mailboxes->push_back(&node_);
		mailboxes_ = mailboxes;
	}

	/* A direct send() is routed through the map so the node gets unlinked. */
	virtual void send(void *msg)
	{
		__mailbox_map.send(mailboxes_, &node_, msg);
	}

	virtual ~__WFNamedMailboxTask()
	{
		// Unlinks only when the base class 'flag' is still unset.
		if (!this->flag)
			__mailbox_map.remove(mailboxes_, &node_);
	}

private:
	struct __mailbox_node node_;
	__NamedMailboxMap::MailboxList *mailboxes_;
};

/* Creates a named mailbox task that stores the message in '*mailbox'. */
WFMailboxTask *__NamedMailboxMap::create(const std::string& name,
										 void **mailbox,
										 mailbox_callback_t&& cb)
{
	auto *task = new __WFNamedMailboxTask(mailbox, std::move(cb));
	mutex_.lock();
	task->push_to(__get_object_list(name, &root_, true));
	mutex_.unlock();
	return task;
}

/* Creates a named mailbox task without an explicit mailbox pointer. */
WFMailboxTask *__NamedMailboxMap::create(const std::string& name,
										 mailbox_callback_t&& cb)
{
	auto *task = new __WFNamedMailboxTask(std::move(cb));
	mutex_.lock();
	task->push_to(__get_object_list(name, &root_, true));
	mutex_.unlock();
	return task;
}

/* Moves up to 'max' nodes (all of them for max == (size_t)-1) from
 * 'mailboxes' to 'task_list'. Returns true if the list was emptied and
 * erased from the tree, false if 'max' ran out first.
 * Must be called with mutex_ held. */
bool __NamedMailboxMap::send_max_locked(MailboxList *mailboxes, size_t max,
										struct list_head *task_list)
{
	if (max == (size_t)-1)
		list_splice(&mailboxes->head, task_list);
	else
	{
		do
		{
			if (max == 0)
				return false;

			list_move_tail(mailboxes->head.next, task_list);
			max--;
		} while (!mailboxes->empty());
	}

	rb_erase(&mailboxes->rb, &root_);
	return true;
}

/* Delivers messages to the mailboxes under 'name'. Receivers are collected
 * under the lock and sent to after it is released. 'msg' advances by 'inc'
 * per receiver. Returns the number of receivers. */
int __NamedMailboxMap::send(const std::string& name, void *const msg[],
							size_t max, int inc)
{
	LIST_HEAD(task_list);
	struct __mailbox_node *node;
	MailboxList *mailboxes;
	bool erased = false;
	int ret = 0;

	mutex_.lock();
	mailboxes = __get_object_list(name, &root_, false);
	if (mailboxes)
		erased = send_max_locked(mailboxes, max, &task_list);

	mutex_.unlock();
	if (erased)
		delete mailboxes;

	while (!list_empty(&task_list))
	{
		node = list_entry(task_list.next, struct __mailbox_node, list);
		list_del(&node->list);
		node->task->WFMailboxTask::send(*msg);
		msg += inc;
		ret++;
	}

	return ret;
}

/* Sends 'msg' to one specific node after unlinking it from its list. */
void __NamedMailboxMap::send(MailboxList *mailboxes,
struct __mailbox_node *node, void *msg)
{
	bool erased;

	mutex_.lock();
	erased = mailboxes->del(node, &root_);
	mutex_.unlock();
	if (erased)
		delete mailboxes;

	// Deliver outside the lock.
	node->task->WFMailboxTask::send(msg);
}

/* Public entry point: named mailbox task writing into '*mailbox'. */
WFMailboxTask *WFTaskFactory::create_mailbox_task(const std::string& name,
												  void **mailbox,
												  mailbox_callback_t callback)
{
	return __mailbox_map.create(name, mailbox, std::move(callback));
}

/* Public entry point: named mailbox task without an explicit mailbox. */
WFMailboxTask *WFTaskFactory::create_mailbox_task(const std::string& name,
												  mailbox_callback_t callback)
{
	return __mailbox_map.create(name, std::move(callback));
}

/* Sends the same 'msg' to at most 'max' mailboxes (pointer step 0). */
int WFTaskFactory::send_by_name(const std::string& name, void *msg,
								size_t max)
{
	return __mailbox_map.send(name, &msg, max, 0);
}

/* Sends msg[0], msg[1], ... to successive mailboxes (pointer step 1). */
template<>
int WFTaskFactory::send_by_name(const std::string& name, void *const msg[],
								size_t max)
{
	return __mailbox_map.send(name, msg, max, 1);
}

/****************** Named Conditional ******************/

class __WFNamedConditional;

/* Links one named conditional into its per-name list. */
struct __conditional_node
{
	struct list_head list;
	__WFNamedConditional *cond;
};

/* Registry of named conditionals (tree of per-name lists guarded by mutex_). */
static class __NamedConditionalMap
{
public:
	// NOTE(review): a template argument list looks to be missing here -- confirm.
	using ConditionalList = __NamedObjectList;

public:
	WFConditional *create(const std::string& name, SubTask *task,
						  void **msgbuf);
	WFConditional *create(const std::string& name, SubTask *task);

	/* Signals at most 'max' conditionals under 'name'; 'inc' is the step
	 * applied to the 'msg' pointer after each signal. */
	int signal(const std::string& name, void *const msg[], size_t max, int inc);
	/* Signals one specific node. */
	void signal(ConditionalList *conds, struct __conditional_node *node,
				void *msg);

	/* Unlinks 'node' from 'conds' under the map mutex. */
	void remove(ConditionalList *conds, struct __conditional_node *node)
	{
		bool erased;

		mutex_.lock();
		erased = conds->del(node, &root_);
		mutex_.unlock();
		if (erased)
			delete conds;
	}

private:
	bool signal_max_locked(ConditionalList *conds, size_t max,
						   struct list_head *cond_list);
	struct rb_root root_;
	std::mutex mutex_;

public:
	__NamedConditionalMap()
	{
		root_.rb_node = NULL;
	}
} __conditional_map;

/* Conditional registered under a name. */
class __WFNamedConditional : public WFConditional
{
public:
	__WFNamedConditional(SubTask *task, void **msgbuf) :
		WFConditional(task, msgbuf)
	{
		node_.cond = this;
	}

	__WFNamedConditional(SubTask *task) :
WFConditional(task)
	{
		node_.cond = this;
	}

	/* Appends this conditional's node to 'conds' and remembers the list. */
	void push_to(__NamedConditionalMap::ConditionalList *conds)
	{
		conds->push_back(&node_);
		conds_ = conds;
	}

	/* A direct signal() is routed through the map so the node gets unlinked. */
	virtual void signal(void *msg)
	{
		__conditional_map.signal(conds_, &node_, msg);
	}

	virtual ~__WFNamedConditional()
	{
		// Unlinks only when the base class 'flag' is still unset.
		if (!this->flag)
			__conditional_map.remove(conds_, &node_);
	}

private:
	struct __conditional_node node_;
	__NamedConditionalMap::ConditionalList *conds_;
};

/* Creates a named conditional that stores the message in '*msgbuf'. */
WFConditional *__NamedConditionalMap::create(const std::string& name,
											 SubTask *task, void **msgbuf)
{
	auto *cond = new __WFNamedConditional(task, msgbuf);
	mutex_.lock();
	cond->push_to(__get_object_list(name, &root_, true));
	mutex_.unlock();
	return cond;
}

/* Creates a named conditional without an explicit message buffer. */
WFConditional *__NamedConditionalMap::create(const std::string& name,
											 SubTask *task)
{
	auto *cond = new __WFNamedConditional(task);
	mutex_.lock();
	cond->push_to(__get_object_list(name, &root_, true));
	mutex_.unlock();
	return cond;
}

/* Moves up to 'max' nodes (all of them for max == (size_t)-1) from 'conds'
 * to 'cond_list'. Returns true if the list was emptied and erased from the
 * tree, false if 'max' ran out first. Must be called with mutex_ held. */
bool __NamedConditionalMap::signal_max_locked(ConditionalList *conds,
											  size_t max,
											  struct list_head *cond_list)
{
	if (max == (size_t)-1)
		list_splice(&conds->head, cond_list);
	else
	{
		do
		{
			if (max == 0)
				return false;

			list_move_tail(conds->head.next, cond_list);
			max--;
		} while (!conds->empty());
	}

	rb_erase(&conds->rb, &root_);
	return true;
}

/* Signals the conditionals under 'name'. Targets are collected under the
 * lock and signaled after it is released. 'msg' advances by 'inc' per target.
 * Returns the number of conditionals signaled. */
int __NamedConditionalMap::signal(const std::string& name, void *const msg[],
								  size_t max, int inc)
{
	LIST_HEAD(cond_list);
	struct __conditional_node *node;
	ConditionalList *conds;
	bool erased = false;
	int ret = 0;

	mutex_.lock();
	conds = __get_object_list(name, &root_, false);
	if (conds)
		erased = signal_max_locked(conds, max, &cond_list);

	mutex_.unlock();
	if (erased)
		delete conds;

	while (!list_empty(&cond_list))
	{
		node = list_entry(cond_list.next, struct __conditional_node, list);
		list_del(&node->list);
		node->cond->WFConditional::signal(*msg);
		msg += inc;
		ret++;
	}

	return ret;
}

/* Signals one specific node after unlinking it from its list. */
void __NamedConditionalMap::signal(ConditionalList *conds,
								   struct __conditional_node *node,
								   void *msg)
{
	bool erased;

	mutex_.lock();
	erased = conds->del(node, &root_);
mutex_.unlock();
	if (erased)
		delete conds;

	// Signal outside the lock.
	node->cond->WFConditional::signal(msg);
}

/* Public entry point: named conditional storing the message in '*msgbuf'. */
WFConditional *WFTaskFactory::create_conditional(const std::string& name,
												 SubTask *task, void **msgbuf)
{
	return __conditional_map.create(name, task, msgbuf);
}

/* Public entry point: named conditional without a message buffer. */
WFConditional *WFTaskFactory::create_conditional(const std::string& name,
												 SubTask *task)
{
	return __conditional_map.create(name, task);
}

/* Signals at most 'max' conditionals with the same 'msg' (pointer step 0). */
int WFTaskFactory::signal_by_name(const std::string& name, void *msg,
								  size_t max)
{
	return __conditional_map.signal(name, &msg, max, 0);
}

/* Signals successive conditionals with msg[0], msg[1], ... (pointer step 1). */
template<>
int WFTaskFactory::signal_by_name(const std::string& name, void *const msg[],
								  size_t max)
{
	return __conditional_map.signal(name, msg, max, 1);
}

/****************** Named Guard ******************/

class __WFNamedGuard;

/* Links one waiting guard into the wait list of its resource name. */
struct __guard_node
{
	struct list_head list;
	__WFNamedGuard *guard;
};

/* Registry of named guards. Unlike the other maps, each per-name list is
 * reference counted and carries its own mutex and 'acquired' state. */
static class __NamedGuardMap
{
public:
	// NOTE(review): the base class template argument looks to be missing -- confirm.
	struct GuardList : public __NamedObjectList
	{
		GuardList(const std::string& name) : __NamedObjectList(name)
		{
			acquired = false;
			refcnt = 0;
		}

		bool acquired;		// set while some guard holds the resource
		size_t refcnt;		// incremented in create(), dropped in release()/unref()
		std::mutex mutex;	// protects 'acquired' and the wait list
	};

public:
	WFConditional *create(const std::string& name, SubTask *task);
	WFConditional *create(const std::string& name, SubTask *task,
						  void **msgbuf);

	/* Drops one reference on 'name' and returns the next waiting guard node,
	 * or NULL if there is none. */
	struct __guard_node *release(const std::string& name);

	/* Drops one reference; erases and deletes the list when it was the last. */
	void unref(GuardList *guards)
	{
		mutex_.lock();
		if (--guards->refcnt == 0)
			rb_erase(&guards->rb, &root_);
		else
			guards = NULL;	// still referenced: make the delete below a no-op

		mutex_.unlock();
		delete guards;
	}

private:
	struct rb_root root_;
	std::mutex mutex_;

public:
	__NamedGuardMap()
	{
		root_.rb_node = NULL;
	}
} __guard_map;

/* Conditional that is signaled when the named resource becomes available. */
class __WFNamedGuard : public WFConditional
{
public:
	__WFNamedGuard(SubTask *task) : WFConditional(task)
	{
		node_.guard = this;
	}

	__WFNamedGuard(SubTask *task, void **msgbuf) : WFConditional(task, msgbuf)
	{
		node_.guard = this;
	}

	virtual ~__WFNamedGuard()
	{
		// Drops the list reference only when the base class 'flag' is unset.
		if (!this->flag)
			__guard_map.unref(guards_);
	}

	SubTask *get_task() const { return this->task; }
	void set_task(SubTask *task) { this->task = task; }

protected:
	virtual void dispatch();
	virtual void
signal(void *msg) { }	// intentionally empty: external signal() calls are ignored

private:
	struct __guard_node node_;
	__NamedGuardMap::GuardList *guards_;
	friend __NamedGuardMap;
};

/* If the resource is already acquired, queues this guard on the wait list;
 * otherwise marks the resource acquired and signals itself immediately.
 * Then continues with the normal conditional dispatch. */
void __WFNamedGuard::dispatch()
{
	guards_->mutex.lock();
	if (guards_->acquired)
		guards_->push_back(&node_);
	else
	{
		guards_->acquired = true;
		this->WFConditional::signal(NULL);
	}

	guards_->mutex.unlock();
	this->WFConditional::dispatch();
}

/* Creates a guard for 'name' and takes one reference on its list. */
WFConditional *__NamedGuardMap::create(const std::string& name,
									   SubTask *task)
{
	auto *guard = new __WFNamedGuard(task);
	mutex_.lock();
	guard->guards_ = __get_object_list(name, &root_, true);
	guard->guards_->refcnt++;
	mutex_.unlock();
	return guard;
}

/* Same as above, with a message buffer. */
WFConditional *__NamedGuardMap::create(const std::string& name,
									   SubTask *task, void **msgbuf)
{
	auto *guard = new __WFNamedGuard(task, msgbuf);
	mutex_.lock();
	guard->guards_ = __get_object_list(name, &root_, true);
	guard->guards_->refcnt++;
	mutex_.unlock();
	return guard;
}

/* Drops one reference on the list for 'name'. If that was the last one, the
 * list is erased and deleted. Otherwise the first waiting node is dequeued
 * and returned, or, when nobody waits, 'acquired' is cleared. */
struct __guard_node *__NamedGuardMap::release(const std::string& name)
{
	struct __guard_node *node = NULL;
	GuardList *guards;

	mutex_.lock();
	guards = __get_object_list(name, &root_, false);
	if (guards)
	{
		if (--guards->refcnt == 0)
			rb_erase(&guards->rb, &root_);
		else
		{
			guards->mutex.lock();
			if (!guards->empty())
			{
				node = list_entry(guards->head.next, struct __guard_node, list);
				list_del(&node->list);
			}
			else
				guards->acquired = false;

			guards->mutex.unlock();
			guards = NULL;	// still referenced: make the delete below a no-op
		}
	}

	mutex_.unlock();
	delete guards;
	return node;
}

/* Public entry point: guard for 'name'. */
WFConditional *WFTaskFactory::create_guard(const std::string& name,
										   SubTask *task)
{
	return __guard_map.create(name, task);
}

/* Public entry point: guard for 'name' with a message buffer. */
WFConditional *WFTaskFactory::create_guard(const std::string& name,
										   SubTask *task, void **msgbuf)
{
	return __guard_map.create(name, task, msgbuf);
}

/* Releases the resource and signals the next waiting guard directly.
 * Returns 1 if a guard was signaled, 0 otherwise. */
int WFTaskFactory::release_guard(const std::string& name, void *msg)
{
	struct __guard_node *node = __guard_map.release(name);

	if (!node)
		return 0;

	node->guard->WFConditional::signal(msg);
	return 1;
}

/* Like release_guard(), but the waiting guard's task is started through a
 * zero-length timer instead of directly. */
int WFTaskFactory::release_guard_safe(const std::string& name, void *msg)
{
	struct __guard_node *node =
__guard_map.release(name);
	WFTimerTask *timer;

	if (!node)
		return 0;

	// The guard's original task is stashed in the timer's user_data; the
	// timer callback pushes it to the front of the timer's series.
	timer = WFTaskFactory::create_timer_task(0, 0, [](WFTimerTask *timer) {
		series_of(timer)->push_front((SubTask *)timer->user_data);
	});
	timer->user_data = node->guard->get_task();
	node->guard->set_task(timer);
	node->guard->WFConditional::signal(msg);
	return 1;
}

/**************** Timed Go Task *****************/

/* Dispatches the go task and starts a timer with the task's time limit.
 * The timer reports back through timer_callback(). */
void __WFTimedGoTask::dispatch()
{
	WFTimerTask *timer;

	timer = WFTaskFactory::create_timer_task(this->seconds, this->nanoseconds,
									__WFTimedGoTask::timer_callback);
	timer->user_data = this;

	this->__WFGoTask::dispatch();
	timer->start();
}

/* Runs the user callback and hands the series its next task. The object is
 * not deleted here; deletion is driven by 'ref'. */
SubTask *__WFTimedGoTask::done()
{
	if (this->callback)
		this->callback(this);

	return series_of(this)->pop();
}

/* Completion path of the go task. 'ref' is decremented twice here and twice
 * in timer_callback(): whichever path first brings it to 3 records the result
 * and finishes the subtask, and whoever brings it to 0 deletes the object.
 * NOTE(review): the initial value of 'ref' is set outside this chunk. */
void __WFTimedGoTask::handle(int state, int error)
{
	if (--this->ref == 3)
	{
		this->state = state;
		this->error = error;
		this->subtask_done();
	}

	if (--this->ref == 0)
		delete this;
}

/* Completion path of the timer. If it gets there first, a timer that ended
 * with WFT_STATE_SUCCESS is reported as WFT_STATE_SYS_ERROR / ETIMEDOUT;
 * any other timer state and error are passed through unchanged. */
void __WFTimedGoTask::timer_callback(WFTimerTask *timer)
{
	__WFTimedGoTask *task = (__WFTimedGoTask *)timer->user_data;

	if (--task->ref == 3)
	{
		if (timer->get_state() == WFT_STATE_SUCCESS)
		{
			task->state = WFT_STATE_SYS_ERROR;
			task->error = ETIMEDOUT;
		}
		else
		{
			task->state = timer->get_state();
			task->error = timer->get_error();
		}

		task->subtask_done();
	}

	if (--task->ref == 0)
		delete task;
}

workflow-0.11.8/src/factory/WFTaskFactory.h000066400000000000000000000351741476003635400205730ustar00rootroot00000000000000/*
  Copyright (c) 2019 Sogou, Inc.

  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

  Unless required by applicable law or agreed to in writing, software
  distributed under the License is distributed on an "AS IS" BASIS,
  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  See the License for the specific language governing permissions and
  limitations under the License.
Authors: Xie Han (xiehan@sogou-inc.com)
           Wu Jiaxu (wujiaxu@sogou-inc.com)
*/

#ifndef _WFTASKFACTORY_H_
#define _WFTASKFACTORY_H_

#include
#include
#include
#include
#include
#include
#include "URIParser.h"
#include "RedisMessage.h"
#include "HttpMessage.h"
#include "MySQLMessage.h"
#include "DnsMessage.h"
#include "Workflow.h"
#include "WFTask.h"
#include "WFGraphTask.h"
#include "EndpointParams.h"

// Network Client/Server tasks

using WFHttpTask = WFNetworkTask;
using http_callback_t = std::function;

using WFRedisTask = WFNetworkTask;
using redis_callback_t = std::function;

using WFMySQLTask = WFNetworkTask;
using mysql_callback_t = std::function;

using WFDnsTask = WFNetworkTask;
using dns_callback_t = std::function;

// File IO tasks

// Arguments of a positional read/write on a single buffer.
struct FileIOArgs
{
	int fd;
	void *buf;
	size_t count;
	off_t offset;
};

// Arguments of a positional vectored read/write.
struct FileVIOArgs
{
	int fd;
	const struct iovec *iov;
	int iovcnt;
	off_t offset;
};

// Arguments of a sync operation.
struct FileSyncArgs
{
	int fd;
};

using WFFileIOTask = WFFileTask;
using fio_callback_t = std::function;

using WFFileVIOTask = WFFileTask;
using fvio_callback_t = std::function;

using WFFileSyncTask = WFFileTask;
using fsync_callback_t = std::function;

// Timer, counter, mailbox and selector
using timer_callback_t = std::function;
using counter_callback_t = std::function;
using mailbox_callback_t = std::function;
using selector_callback_t = std::function;

// Graph (DAG) task.
using graph_callback_t = std::function;

using WFEmptyTask = WFGenericTask;

using WFDynamicTask = WFGenericTask;
using dynamic_create_t = std::function;

using repeated_create_t = std::function;
using repeater_callback_t = std::function;

using module_callback_t = std::function;

/* Static factory for the framework's built-in task types. */
class WFTaskFactory
{
public:
	/* HTTP client tasks, addressed by URL string or pre-parsed URI,
	 * optionally through a proxy. */
	static WFHttpTask *create_http_task(const std::string& url,
										int redirect_max,
										int retry_max,
										http_callback_t callback);

	static WFHttpTask *create_http_task(const ParsedURI& uri,
										int redirect_max,
										int retry_max,
										http_callback_t callback);

	static WFHttpTask *create_http_task(const std::string& url,
										const std::string& proxy_url,
										int redirect_max,
										int retry_max,
										http_callback_t callback);

	static WFHttpTask *create_http_task(const ParsedURI& uri,
										const ParsedURI& proxy_uri,
										int redirect_max,
										int retry_max,
										http_callback_t callback);

	/* Redis client tasks. */
	static WFRedisTask *create_redis_task(const std::string& url,
										  int retry_max,
										  redis_callback_t callback);

	static WFRedisTask *create_redis_task(const ParsedURI& uri,
										  int retry_max,
										  redis_callback_t callback);

	/* MySQL client tasks. */
	static WFMySQLTask *create_mysql_task(const std::string& url,
										  int retry_max,
										  mysql_callback_t callback);

	static WFMySQLTask *create_mysql_task(const ParsedURI& uri,
										  int retry_max,
										  mysql_callback_t callback);

	/* DNS client tasks. */
	static WFDnsTask *create_dns_task(const std::string& url,
									  int retry_max,
									  dns_callback_t callback);

	static WFDnsTask *create_dns_task(const ParsedURI& uri,
									  int retry_max,
									  dns_callback_t callback);

public:
	/* File tasks operating on an already opened file descriptor. */
	static WFFileIOTask *create_pread_task(int fd,
										   void *buf,
										   size_t count,
										   off_t offset,
										   fio_callback_t callback);

	static WFFileIOTask *create_pwrite_task(int fd,
											const void *buf,
											size_t count,
											off_t offset,
											fio_callback_t callback);

	static WFFileVIOTask *create_preadv_task(int fd,
											 const struct iovec *iov,
											 int iovcnt,
											 off_t offset,
											 fvio_callback_t callback);

	static WFFileVIOTask *create_pwritev_task(int fd,
											  const struct iovec *iov,
											  int iovcnt,
											  off_t offset,
											  fvio_callback_t callback);

	static WFFileSyncTask *create_fsync_task(int fd,
fsync_callback_t callback);

	/* On systems that do not support fdatasync(), like macOS,
	 * an fdsync task is equivalent to an fsync task. */
	static WFFileSyncTask *create_fdsync_task(int fd,
											  fsync_callback_t callback);

	/* File tasks with path name. */
public:
	static WFFileIOTask *create_pread_task(const std::string& path,
										   void *buf,
										   size_t count,
										   off_t offset,
										   fio_callback_t callback);

	static WFFileIOTask *create_pwrite_task(const std::string& path,
											const void *buf,
											size_t count,
											off_t offset,
											fio_callback_t callback);

	static WFFileVIOTask *create_preadv_task(const std::string& path,
											 const struct iovec *iov,
											 int iovcnt,
											 off_t offset,
											 fvio_callback_t callback);

	static WFFileVIOTask *create_pwritev_task(const std::string& path,
											  const struct iovec *iov,
											  int iovcnt,
											  off_t offset,
											  fvio_callback_t callback);

public:
	/* Create an unnamed timer. */
	static WFTimerTask *create_timer_task(time_t seconds, long nanoseconds,
										  timer_callback_t callback);

	/* Create a named timer. */
	static WFTimerTask *create_timer_task(const std::string& timer_name,
										  time_t seconds, long nanoseconds,
										  timer_callback_t callback);

	/* Cancel all timers under the name. */
	static int cancel_by_name(const std::string& timer_name)
	{
		return WFTaskFactory::cancel_by_name(timer_name, (size_t)-1);
	}

	/* Cancel at most 'max' timers under the name. */
	static int cancel_by_name(const std::string& timer_name, size_t max);

	/* Timer in microseconds (deprecated). */
	static WFTimerTask *create_timer_task(unsigned int microseconds,
										  timer_callback_t callback);

public:
	/* Create an unnamed counter. Call counter->count() directly.
	 * NOTE: never call count() more times than target_value. */
	static WFCounterTask *create_counter_task(unsigned int target_value,
											  counter_callback_t callback)
	{
		return new WFCounterTask(target_value, std::move(callback));
	}

	/* Create a named counter. */
	static WFCounterTask *create_counter_task(const std::string& counter_name,
											  unsigned int target_value,
											  counter_callback_t callback);

	/* Count by a counter's name.
When using count_by_name(), it's safe to count
	 * beyond target_value. When multiple counters share the same name,
	 * this operation is performed on the one created first. */
	static int count_by_name(const std::string& counter_name)
	{
		return WFTaskFactory::count_by_name(counter_name, 1);
	}

	/* Count by name with a value n. When multiple counters share this name,
	 * the operation is performed on the counters in the order of their
	 * creation, and more than one counter may reach its target value. */
	static int count_by_name(const std::string& counter_name, unsigned int n);

public:
	/* Create an unnamed mailbox task that stores the message in '*mailbox'. */
	static WFMailboxTask *create_mailbox_task(void **mailbox,
											  mailbox_callback_t callback)
	{
		return new WFMailboxTask(mailbox, std::move(callback));
	}

	/* Use 'user_data' as mailbox. */
	static WFMailboxTask *create_mailbox_task(mailbox_callback_t callback)
	{
		return new WFMailboxTask(std::move(callback));
	}

	/* Named variants of the two functions above. */
	static WFMailboxTask *create_mailbox_task(const std::string& mailbox_name,
											  void **mailbox,
											  mailbox_callback_t callback);

	static WFMailboxTask *create_mailbox_task(const std::string& mailbox_name,
											  mailbox_callback_t callback);

	/* The 'msg' will be sent to all the mailbox tasks under the name, and
	 * will be lost if no task matches.
*/
	static int send_by_name(const std::string& mailbox_name, void *msg)
	{
		return WFTaskFactory::send_by_name(mailbox_name, msg, (size_t)-1);
	}

	/* Send 'msg' to at most 'max' mailbox tasks under the name. */
	static int send_by_name(const std::string& mailbox_name, void *msg,
							size_t max);

	/* Send from an array of messages to at most 'max' mailbox tasks. */
	template
	static int send_by_name(const std::string& mailbox_name, T *const msg[],
							size_t max);

public:
	static WFSelectorTask *create_selector_task(size_t candidates,
												selector_callback_t callback)
	{
		return new WFSelectorTask(candidates, std::move(callback));
	}

public:
	/* Unnamed conditionals wrapping 'task'. */
	static WFConditional *create_conditional(SubTask *task, void **msgbuf)
	{
		return new WFConditional(task, msgbuf);
	}

	static WFConditional *create_conditional(SubTask *task)
	{
		return new WFConditional(task);
	}

	/* Named conditionals wrapping 'task'. */
	static WFConditional *create_conditional(const std::string& cond_name,
											 SubTask *task, void **msgbuf);

	static WFConditional *create_conditional(const std::string& cond_name,
											 SubTask *task);

	/* Signal all conditionals under the name. */
	static int signal_by_name(const std::string& cond_name, void *msg)
	{
		return WFTaskFactory::signal_by_name(cond_name, msg, (size_t)-1);
	}

	/* Signal at most 'max' conditionals under the name. */
	static int signal_by_name(const std::string& cond_name, void *msg,
							  size_t max);

	/* Signal at most 'max' conditionals from an array of messages. */
	template
	static int signal_by_name(const std::string& cond_name, T *const msg[],
							  size_t max);

public:
	static WFConditional *create_guard(const std::string& resource_name,
									   SubTask *task);

	static WFConditional *create_guard(const std::string& resource_name,
									   SubTask *task, void **msgbuf);

	/* The 'guard' is acquired once it has been started, so call
	   'release_guard' after, and only after, the task is finished, typically
	   in its callback. The function returns 1 if another guard is signaled,
	   otherwise returns 0. */
	static int release_guard(const std::string& resource_name, void *msg);

	static int release_guard_safe(const std::string& resource_name, void *msg);

public:
	template
	static WFGoTask *create_go_task(const std::string& queue_name,
									FUNC&& func, ARGS&&... args);

	/* Create 'Go' task with running time limit in seconds plus nanoseconds.
* If the time limit is exceeded, the callback receives state
	 * WFT_STATE_SYS_ERROR and error ETIMEDOUT. */
	template
	static WFGoTask *create_timedgo_task(time_t seconds, long nanoseconds,
										 const std::string& queue_name,
										 FUNC&& func, ARGS&&... args);

	/* Create 'Go' task on user's executor and execution queue. */
	template
	static WFGoTask *create_go_task(ExecQueue *queue, Executor *executor,
									FUNC&& func, ARGS&&... args);

	template
	static WFGoTask *create_timedgo_task(time_t seconds, long nanoseconds,
										 ExecQueue *queue, Executor *executor,
										 FUNC&& func, ARGS&&... args);

	/* For capturing 'task' itself in go task's running function. */
	template
	static void reset_go_task(WFGoTask *task, FUNC&& func, ARGS&&... args);

public:
	static WFGraphTask *create_graph_task(graph_callback_t callback)
	{
		return new WFGraphTask(std::move(callback));
	}

public:
	static WFEmptyTask *create_empty_task()
	{
		return new WFEmptyTask;
	}

	static WFDynamicTask *create_dynamic_task(dynamic_create_t create);

	static WFRepeaterTask *create_repeater_task(repeated_create_t create,
												repeater_callback_t callback)
	{
		return new WFRepeaterTask(std::move(create), std::move(callback));
	}

public:
	static WFModuleTask *create_module_task(SubTask *first,
											module_callback_t callback)
	{
		return new WFModuleTask(first, std::move(callback));
	}

	/* Same as above, additionally setting 'last' as the last task of the
	 * module's sub series. */
	static WFModuleTask *create_module_task(SubTask *first, SubTask *last,
											module_callback_t callback)
	{
		WFModuleTask *task = new WFModuleTask(first, std::move(callback));
		task->sub_series()->set_last_task(last);
		return task;
	}
};

/* Factory for generic client/server network tasks.
 * NOTE(review): the template parameter list looks to be missing -- confirm. */
template
class WFNetworkTaskFactory
{
private:
	using T = WFNetworkTask;

public:
	static T *create_client_task(enum TransportType type,
								 const std::string& host,
								 unsigned short port,
								 int retry_max,
								 std::function callback);

	static T *create_client_task(enum TransportType type,
								 const std::string& url,
								 int retry_max,
								 std::function callback);

	static T *create_client_task(enum TransportType type,
								 const ParsedURI& uri,
								 int retry_max,
								 std::function callback);

	static T
*create_client_task(enum TransportType type,
								 const struct sockaddr *addr,
								 socklen_t addrlen,
								 int retry_max,
								 std::function callback);

	/* Same as above, with a caller-provided SSL context. */
	static T *create_client_task(enum TransportType type,
								 const struct sockaddr *addr,
								 socklen_t addrlen,
								 SSL_CTX *ssl_ctx,
								 int retry_max,
								 std::function callback);

public:
	static T *create_server_task(CommService *service,
								 std::function& process);
};

/* Factory for thread (computing) tasks.
 * NOTE(review): the template parameter list looks to be missing -- confirm. */
template
class WFThreadTaskFactory
{
private:
	using T = WFThreadTask;

public:
	static T *create_thread_task(const std::string& queue_name,
								 std::function routine,
								 std::function callback);

	/* Create thread task with running time limit. */
	static T *create_thread_task(time_t seconds, long nanoseconds,
								 const std::string& queue_name,
								 std::function routine,
								 std::function callback);

public:
	/* Create thread task on user's executor and execution queue. */
	static T *create_thread_task(ExecQueue *queue, Executor *executor,
								 std::function routine,
								 std::function callback);

	/* With running time limit. */
	static T *create_thread_task(time_t seconds, long nanoseconds,
								 ExecQueue *queue, Executor *executor,
								 std::function routine,
								 std::function callback);
};

#include "WFTaskFactory.inl"

#endif

workflow-0.11.8/src/factory/WFTaskFactory.inl000066400000000000000000000547701476003635400211310ustar00rootroot00000000000000/*
  Copyright (c) 2019 Sogou, Inc.

  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

  Unless required by applicable law or agreed to in writing, software
  distributed under the License is distributed on an "AS IS" BASIS,
  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  See the License for the specific language governing permissions and
  limitations under the License.
Authors: Xie Han (xiehan@sogou-inc.com)
           Wu Jiaxu (wujiaxu@sogou-inc.com)
           Li Yingxin (liyingxin@sogou-inc.com)
*/

#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include "WFGlobal.h"
#include "Workflow.h"
#include "WFTask.h"
#include "RouteManager.h"
#include "URIParser.h"
#include "WFTaskError.h"
#include "EndpointParams.h"
#include "WFNameService.h"

/* Task that, when dispatched, calls 'create' and pushes the returned task to
 * the front of its own series before completing. */
class __WFDynamicTask : public WFDynamicTask
{
protected:
	virtual void dispatch()
	{
		series_of(this)->push_front(this->create(this));
		this->WFDynamicTask::dispatch();
	}

protected:
	std::function create;

public:
	__WFDynamicTask(std::function&& create) :
		create(std::move(create))
	{
	}
};

inline WFDynamicTask *
WFTaskFactory::create_dynamic_task(dynamic_create_t create)
{
	return new __WFDynamicTask(std::move(create));
}

/* Declaration of the explicit specialization defined in the .cc file. */
template<>
int WFTaskFactory::send_by_name(const std::string&, void *const *, size_t);

/* Generic overload: forwards the message array as 'void *const *'. */
template
int WFTaskFactory::send_by_name(const std::string& mailbox_name,
								T *const msg[], size_t max)
{
	return WFTaskFactory::send_by_name(mailbox_name, (void *const *)msg, max);
}

/* Declaration of the explicit specialization defined in the .cc file. */
template<>
int WFTaskFactory::signal_by_name(const std::string&, void *const *, size_t);

/* Generic overload: forwards the message array as 'void *const *'. */
template
int WFTaskFactory::signal_by_name(const std::string& cond_name,
								  T *const msg[], size_t max)
{
	return WFTaskFactory::signal_by_name(cond_name, (void *const *)msg, max);
}

/* Client task with URI/address initialization, routing, retry and redirect
 * state. NOTE(review): template parameter lists look to be missing on this
 * class and its base -- confirm. */
template
class WFComplexClientTask : public WFClientTask
{
protected:
	using task_callback_t = std::function *)>;

public:
	WFComplexClientTask(int retry_max, task_callback_t&& cb):
		WFClientTask(NULL, WFGlobal::get_scheduler(), std::move(cb))
	{
		type_ = TT_TCP;
		ssl_ctx_ = NULL;
		fixed_addr_ = false;
		fixed_conn_ = false;
		retry_max_ = retry_max;
		retry_times_ = 0;
		redirect_ = false;
		ns_policy_ = NULL;
		router_task_ = NULL;
	}

protected:
	// new api for children
	virtual bool init_success() { return true; }
	virtual void init_failed() {}
	virtual bool check_request() { return true; }
	virtual WFRouterTask *route();
	virtual bool finish_once() {
return true; } public: void init(const ParsedURI& uri) { uri_ = uri; init_with_uri(); } void init(ParsedURI&& uri) { uri_ = std::move(uri); init_with_uri(); } void init(enum TransportType type, const struct sockaddr *addr, socklen_t addrlen, const std::string& info); void set_transport_type(enum TransportType type) { type_ = type; } enum TransportType get_transport_type() const { return type_; } void set_ssl_ctx(SSL_CTX *ssl_ctx) { ssl_ctx_ = ssl_ctx; } virtual const ParsedURI *get_current_uri() const { return &uri_; } void set_redirect(const ParsedURI& uri) { redirect_ = true; init(uri); } void set_redirect(enum TransportType type, const struct sockaddr *addr, socklen_t addrlen, const std::string& info) { redirect_ = true; init(type, addr, addrlen, info); } bool is_fixed_addr() const { return this->fixed_addr_; } bool is_fixed_conn() const { return this->fixed_conn_; } protected: void set_fixed_addr(int fixed) { this->fixed_addr_ = fixed; } void set_fixed_conn(int fixed) { this->fixed_conn_ = fixed; } void set_info(const std::string& info) { info_.assign(info); } void set_info(const char *info) { info_.assign(info); } protected: virtual void dispatch(); virtual SubTask *done(); void clear_resp() { protocol::ProtocolMessage head(std::move(this->resp)); this->resp.~RESP(); new(&this->resp) RESP; *(protocol::ProtocolMessage *)&this->resp = std::move(head); } void disable_retry() { retry_times_ = retry_max_; } protected: enum TransportType type_; ParsedURI uri_; std::string info_; SSL_CTX *ssl_ctx_; bool fixed_addr_; bool fixed_conn_; bool redirect_; CTX ctx_; int retry_max_; int retry_times_; WFNSPolicy *ns_policy_; WFRouterTask *router_task_; RouteManager::RouteResult route_result_; WFNSTracing tracing_; public: CTX *get_mutable_ctx() { return &ctx_; } private: void clear_prev_state(); void init_with_uri(); bool set_port(); void router_callback(void *t); void switch_callback(void *t); }; template void WFComplexClientTask::clear_prev_state() { ns_policy_ = NULL; 
route_result_.clear(); if (tracing_.deleter) { tracing_.deleter(tracing_.data); tracing_.deleter = NULL; } tracing_.data = NULL; retry_times_ = 0; this->state = WFT_STATE_UNDEFINED; this->error = 0; this->timeout_reason = TOR_NOT_TIMEOUT; } template void WFComplexClientTask::init(enum TransportType type, const struct sockaddr *addr, socklen_t addrlen, const std::string& info) { if (redirect_) clear_prev_state(); auto params = WFGlobal::get_global_settings()->endpoint_params; struct addrinfo addrinfo = { }; addrinfo.ai_family = addr->sa_family; addrinfo.ai_addr = (struct sockaddr *)addr; addrinfo.ai_addrlen = addrlen; type_ = type; info_.assign(info); params.use_tls_sni = false; if (WFGlobal::get_route_manager()->get(type, &addrinfo, info_, ¶ms, "", ssl_ctx_, route_result_) < 0) { this->state = WFT_STATE_SYS_ERROR; this->error = errno; } else if (this->init_success()) return; this->init_failed(); } template bool WFComplexClientTask::set_port() { if (uri_.port) { int port = atoi(uri_.port); if (port <= 0 || port > 65535) { this->state = WFT_STATE_TASK_ERROR; this->error = WFT_ERR_URI_PORT_INVALID; return false; } return true; } if (uri_.scheme) { const char *port_str = WFGlobal::get_default_port(uri_.scheme); if (port_str) { uri_.port = strdup(port_str); if (uri_.port) return true; this->state = WFT_STATE_SYS_ERROR; this->error = errno; return false; } } this->state = WFT_STATE_TASK_ERROR; this->error = WFT_ERR_URI_SCHEME_INVALID; return false; } template void WFComplexClientTask::init_with_uri() { if (redirect_) { clear_prev_state(); ns_policy_ = WFGlobal::get_dns_resolver(); } if (uri_.state == URI_STATE_SUCCESS) { if (this->set_port()) { if (this->init_success()) return; } } else if (uri_.state == URI_STATE_ERROR) { this->state = WFT_STATE_SYS_ERROR; this->error = uri_.error; } else { this->state = WFT_STATE_TASK_ERROR; this->error = WFT_ERR_URI_PARSE_FAILED; } this->init_failed(); } template WFRouterTask *WFComplexClientTask::route() { auto&& cb = 
std::bind(&WFComplexClientTask::router_callback, this, std::placeholders::_1); struct WFNSParams params = { .type = type_, .uri = uri_, .info = info_.c_str(), .ssl_ctx = ssl_ctx_, .fixed_addr = fixed_addr_, .fixed_conn = fixed_conn_, .retry_times = retry_times_, .tracing = &tracing_, }; if (!ns_policy_) { WFNameService *ns = WFGlobal::get_name_service(); ns_policy_ = ns->get_policy(uri_.host ? uri_.host : ""); } return ns_policy_->create_router_task(¶ms, std::move(cb)); } template void WFComplexClientTask::router_callback(void *t) { WFRouterTask *task = (WFRouterTask *)t; this->state = task->get_state(); if (this->state == WFT_STATE_SUCCESS) route_result_ = std::move(*task->get_result()); else if (this->state == WFT_STATE_UNDEFINED) { /* should not happend */ this->state = WFT_STATE_SYS_ERROR; this->error = ENOSYS; } else this->error = task->get_error(); } template void WFComplexClientTask::dispatch() { switch (this->state) { case WFT_STATE_UNDEFINED: if (this->check_request()) { if (this->route_result_.request_object) { case WFT_STATE_SUCCESS: this->set_request_object(route_result_.request_object); this->WFClientTask::dispatch(); return; } router_task_ = this->route(); series_of(this)->push_front(this); series_of(this)->push_front(router_task_); } default: break; } this->subtask_done(); } template void WFComplexClientTask::switch_callback(void *t) { if (!redirect_) { if (this->state == WFT_STATE_SYS_ERROR && this->error < 0) { this->state = WFT_STATE_SSL_ERROR; this->error = -this->error; } if (tracing_.deleter) { tracing_.deleter(tracing_.data); tracing_.deleter = NULL; } if (this->callback) this->callback(this); } if (redirect_) { redirect_ = false; clear_resp(); this->target = NULL; series_of(this)->push_front(this); } else delete this; } template SubTask *WFComplexClientTask::done() { SeriesWork *series = series_of(this); if (router_task_) { router_task_ = NULL; return series->pop(); } bool is_user_request = this->finish_once(); if (ns_policy_) { if 
(this->state == WFT_STATE_SYS_ERROR || this->state == WFT_STATE_DNS_ERROR) { ns_policy_->failed(&route_result_, &tracing_, this->target); } else if (route_result_.request_object) { ns_policy_->success(&route_result_, &tracing_, this->target); } } if (this->state == WFT_STATE_SUCCESS) { if (!is_user_request) return this; } else if (this->state == WFT_STATE_SYS_ERROR) { if (retry_times_ < retry_max_) { redirect_ = true; if (ns_policy_) route_result_.clear(); this->state = WFT_STATE_UNDEFINED; this->error = 0; this->timeout_reason = 0; retry_times_++; } } /* When the target or the connection is NULL, it's very likely that we are * in the caller's thread. Running a timer will switch callback function to * a handler thread, and this can prevent stack overflow. */ if (!this->target || !this->CommSession::get_connection()) { auto&& cb = std::bind(&WFComplexClientTask::switch_callback, this, std::placeholders::_1); WFTimerTask *timer; timer = WFTaskFactory::create_timer_task(0, 0, std::move(cb)); series->push_front(timer); } else this->switch_callback(NULL); return series->pop(); } /**********Template Network Factory**********/ template WFNetworkTask * WFNetworkTaskFactory::create_client_task(enum TransportType type, const std::string& host, unsigned short port, int retry_max, std::function *)> callback) { auto *task = new WFComplexClientTask(retry_max, std::move(callback)); ParsedURI uri; char buf[32]; sprintf(buf, "%u", port); uri.scheme = strdup("scheme"); uri.host = strdup(host.c_str()); uri.port = strdup(buf); if (!uri.scheme || !uri.host || !uri.port) { uri.state = URI_STATE_ERROR; uri.error = errno; } else uri.state = URI_STATE_SUCCESS; task->init(std::move(uri)); task->set_transport_type(type); return task; } template WFNetworkTask * WFNetworkTaskFactory::create_client_task(enum TransportType type, const std::string& url, int retry_max, std::function *)> callback) { auto *task = new WFComplexClientTask(retry_max, std::move(callback)); ParsedURI uri; 
URIParser::parse(url, uri); task->init(std::move(uri)); task->set_transport_type(type); return task; } template WFNetworkTask * WFNetworkTaskFactory::create_client_task(enum TransportType type, const ParsedURI& uri, int retry_max, std::function *)> callback) { auto *task = new WFComplexClientTask(retry_max, std::move(callback)); task->init(uri); task->set_transport_type(type); return task; } template WFNetworkTask * WFNetworkTaskFactory::create_client_task(enum TransportType type, const struct sockaddr *addr, socklen_t addrlen, int retry_max, std::function *)> callback) { auto *task = new WFComplexClientTask(retry_max, std::move(callback)); task->init(type, addr, addrlen, ""); return task; } template WFNetworkTask * WFNetworkTaskFactory::create_client_task(enum TransportType type, const struct sockaddr *addr, socklen_t addrlen, SSL_CTX *ssl_ctx, int retry_max, std::function *)> callback) { auto *task = new WFComplexClientTask(retry_max, std::move(callback)); task->set_ssl_ctx(ssl_ctx); task->init(type, addr, addrlen, ""); return task; } template WFNetworkTask * WFNetworkTaskFactory::create_server_task(CommService *service, std::function *)>& process) { return new WFServerTask(service, WFGlobal::get_scheduler(), process); } /**********Server Factory**********/ class WFServerTaskFactory { public: static WFDnsTask *create_dns_task(CommService *service, std::function& process); static WFHttpTask *create_http_task(CommService *service, std::function& process); static WFMySQLTask *create_mysql_task(CommService *service, std::function& process); }; /************Go Task Factory************/ class __WFGoTask : public WFGoTask { public: void set_go_func(std::function func) { this->go = std::move(func); } protected: virtual void execute() { this->go(); } protected: std::function go; public: __WFGoTask(ExecQueue *queue, Executor *executor, std::function&& func) : WFGoTask(queue, executor), go(std::move(func)) { } }; class __WFTimedGoTask : public __WFGoTask { protected: 
virtual void dispatch(); virtual SubTask *done(); protected: virtual void handle(int state, int error); protected: static void timer_callback(WFTimerTask *timer); protected: time_t seconds; long nanoseconds; std::atomic ref; public: __WFTimedGoTask(time_t seconds, long nanoseconds, ExecQueue *queue, Executor *executor, std::function&& func) : __WFGoTask(queue, executor, std::move(func)), ref(4) { this->seconds = seconds; this->nanoseconds = nanoseconds; } }; template WFGoTask *WFTaskFactory::create_go_task(const std::string& queue_name, FUNC&& func, ARGS&&... args) { auto&& tmp = std::bind(std::forward(func), std::forward(args)...); return new __WFGoTask(WFGlobal::get_exec_queue(queue_name), WFGlobal::get_compute_executor(), std::move(tmp)); } template WFGoTask *WFTaskFactory::create_timedgo_task(time_t seconds, long nanoseconds, const std::string& queue_name, FUNC&& func, ARGS&&... args) { auto&& tmp = std::bind(std::forward(func), std::forward(args)...); return new __WFTimedGoTask(seconds, nanoseconds, WFGlobal::get_exec_queue(queue_name), WFGlobal::get_compute_executor(), std::move(tmp)); } template WFGoTask *WFTaskFactory::create_go_task(ExecQueue *queue, Executor *executor, FUNC&& func, ARGS&&... args) { auto&& tmp = std::bind(std::forward(func), std::forward(args)...); return new __WFGoTask(queue, executor, std::move(tmp)); } template WFGoTask *WFTaskFactory::create_timedgo_task(time_t seconds, long nanoseconds, ExecQueue *queue, Executor *executor, FUNC&& func, ARGS&&... args) { auto&& tmp = std::bind(std::forward(func), std::forward(args)...); return new __WFTimedGoTask(seconds, nanoseconds, queue, executor, std::move(tmp)); } template void WFTaskFactory::reset_go_task(WFGoTask *task, FUNC&& func, ARGS&&... 
args) { auto&& tmp = std::bind(std::forward(func), std::forward(args)...); ((__WFGoTask *)task)->set_go_func(std::move(tmp)); } /**********Create go task with nullptr func**********/ template<> inline WFGoTask *WFTaskFactory::create_go_task(const std::string& queue_name, std::nullptr_t&&) { return new __WFGoTask(WFGlobal::get_exec_queue(queue_name), WFGlobal::get_compute_executor(), nullptr); } template<> inline WFGoTask *WFTaskFactory::create_timedgo_task(time_t seconds, long nanoseconds, const std::string& queue_name, std::nullptr_t&&) { return new __WFTimedGoTask(seconds, nanoseconds, WFGlobal::get_exec_queue(queue_name), WFGlobal::get_compute_executor(), nullptr); } template<> inline WFGoTask *WFTaskFactory::create_go_task(ExecQueue *queue, Executor *executor, std::nullptr_t&&) { return new __WFGoTask(queue, executor, nullptr); } template<> inline WFGoTask *WFTaskFactory::create_timedgo_task(time_t seconds, long nanoseconds, ExecQueue *queue, Executor *executor, std::nullptr_t&&) { return new __WFTimedGoTask(seconds, nanoseconds, queue, executor, nullptr); } template<> inline void WFTaskFactory::reset_go_task(WFGoTask *task, std::nullptr_t&&) { ((__WFGoTask *)task)->set_go_func(nullptr); } /**********Template Thread Task Factory**********/ template class __WFThreadTask : public WFThreadTask { protected: virtual void execute() { this->routine(&this->input, &this->output); } protected: std::function routine; public: __WFThreadTask(ExecQueue *queue, Executor *executor, std::function&& rt, std::function *)>&& cb) : WFThreadTask(queue, executor, std::move(cb)), routine(std::move(rt)) { } }; template class __WFTimedThreadTask : public __WFThreadTask { protected: virtual void dispatch(); virtual SubTask *done(); protected: virtual void handle(int state, int error); protected: static void timer_callback(WFTimerTask *timer); protected: time_t seconds; long nanoseconds; std::atomic ref; public: __WFTimedThreadTask(time_t seconds, long nanoseconds, ExecQueue *queue, 
Executor *executor, std::function&& rt, std::function *)>&& cb) : __WFThreadTask(queue, executor, std::move(rt), std::move(cb)), ref(4) { this->seconds = seconds; this->nanoseconds = nanoseconds; } }; template void __WFTimedThreadTask::dispatch() { WFTimerTask *timer; timer = WFTaskFactory::create_timer_task(this->seconds, this->nanoseconds, __WFTimedThreadTask::timer_callback); timer->user_data = this; this->__WFThreadTask::dispatch(); timer->start(); } template SubTask *__WFTimedThreadTask::done() { if (this->callback) this->callback(this); return series_of(this)->pop(); } template void __WFTimedThreadTask::handle(int state, int error) { if (--this->ref == 3) { this->state = state; this->error = error; this->subtask_done(); } if (--this->ref == 0) delete this; } template void __WFTimedThreadTask::timer_callback(WFTimerTask *timer) { auto *task = (__WFTimedThreadTask *)timer->user_data; if (--task->ref == 3) { if (timer->get_state() == WFT_STATE_SUCCESS) { task->state = WFT_STATE_SYS_ERROR; task->error = ETIMEDOUT; } else { task->state = timer->get_state(); task->error = timer->get_error(); } task->subtask_done(); } if (--task->ref == 0) delete task; } template WFThreadTask * WFThreadTaskFactory::create_thread_task(const std::string& queue_name, std::function routine, std::function *)> callback) { return new __WFThreadTask(WFGlobal::get_exec_queue(queue_name), WFGlobal::get_compute_executor(), std::move(routine), std::move(callback)); } template WFThreadTask * WFThreadTaskFactory::create_thread_task(time_t seconds, long nanoseconds, const std::string& queue_name, std::function routine, std::function *)> callback) { return new __WFTimedThreadTask(seconds, nanoseconds, WFGlobal::get_exec_queue(queue_name), WFGlobal::get_compute_executor(), std::move(routine), std::move(callback)); } template WFThreadTask * WFThreadTaskFactory::create_thread_task(ExecQueue *queue, Executor *executor, std::function routine, std::function *)> callback) { return new 
__WFThreadTask(queue, executor, std::move(routine), std::move(callback)); } template WFThreadTask * WFThreadTaskFactory::create_thread_task(time_t seconds, long nanoseconds, ExecQueue *queue, Executor *executor, std::function routine, std::function *)> callback) { return new __WFTimedThreadTask(seconds, nanoseconds, queue, executor, std::move(routine), std::move(callback)); } workflow-0.11.8/src/factory/Workflow.cc000066400000000000000000000114551476003635400200500ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Author: Xie Han (xiehan@sogou-inc.com) */ #include #include #include #include #include #include #include "Workflow.h" SeriesWork::SeriesWork(SubTask *first, series_callback_t&& cb) : callback(std::move(cb)) { this->queue = this->buf; this->queue_size = sizeof this->buf / sizeof *this->buf; this->front = 0; this->back = 0; this->canceled = false; this->finished = false; assert(!series_of(first)); first->set_pointer(this); this->first = first; this->last = NULL; this->context = NULL; this->in_parallel = NULL; } SeriesWork::~SeriesWork() { if (this->queue != this->buf) delete []this->queue; } void SeriesWork::dismiss_recursive() { SubTask *task = first; this->callback = nullptr; do { delete task; task = this->pop_task(); } while (task); } void SeriesWork::expand_queue() { int size = 2 * this->queue_size; SubTask **queue = new SubTask *[size]; int i, j; i = 0; j = this->front; do { queue[i++] = this->queue[j++]; if (j == this->queue_size) j = 0; } while (j != this->back); if (this->queue != this->buf) delete []this->queue; this->queue = queue; this->queue_size = size; this->front = 0; this->back = i; } void SeriesWork::push_front(SubTask *task) { this->mutex.lock(); if (--this->front == -1) this->front = this->queue_size - 1; task->set_pointer(this); this->queue[this->front] = task; if (this->front == this->back) this->expand_queue(); this->mutex.unlock(); } void SeriesWork::push_back(SubTask *task) { this->mutex.lock(); task->set_pointer(this); this->queue[this->back] = task; if (++this->back == this->queue_size) this->back = 0; if (this->front == this->back) this->expand_queue(); this->mutex.unlock(); } SubTask *SeriesWork::pop() { bool canceled = this->canceled; SubTask *task = this->pop_task(); if (!canceled) return task; while (task) { delete task; task = this->pop_task(); } return NULL; } SubTask *SeriesWork::pop_task() { SubTask *task; this->mutex.lock(); if (this->front != this->back) { task = this->queue[this->front]; if (++this->front == this->queue_size) 
this->front = 0; } else { task = this->last; this->last = NULL; } this->mutex.unlock(); if (!task) { this->finished = true; if (this->callback) this->callback(this); if (!this->in_parallel) delete this; } return task; } ParallelWork::ParallelWork(parallel_callback_t&& cb) : ParallelTask(new SubTask *[2 * 4], 0), callback(std::move(cb)) { this->buf_size = 4; this->all_series = (SeriesWork **)&this->subtasks[this->buf_size]; this->context = NULL; } ParallelWork::ParallelWork(SeriesWork *const all_series[], size_t n, parallel_callback_t&& cb) : ParallelTask(new SubTask *[2 * (n > 4 ? n : 4)], n), callback(std::move(cb)) { size_t i; this->buf_size = (n > 4 ? n : 4); this->all_series = (SeriesWork **)&this->subtasks[this->buf_size]; for (i = 0; i < n; i++) { assert(!all_series[i]->in_parallel); all_series[i]->in_parallel = this; this->all_series[i] = all_series[i]; this->subtasks[i] = all_series[i]->first; } this->context = NULL; } void ParallelWork::expand_buf() { SubTask **buf; size_t size; this->buf_size *= 2; buf = new SubTask *[2 * this->buf_size]; size = this->subtasks_nr * sizeof (void *); memcpy(buf, this->subtasks, size); memcpy(buf + this->buf_size, this->all_series, size); delete []this->subtasks; this->subtasks = buf; this->all_series = (SeriesWork **)&buf[this->buf_size]; } void ParallelWork::add_series(SeriesWork *series) { if (this->subtasks_nr == this->buf_size) this->expand_buf(); assert(!series->in_parallel); series->in_parallel = this; this->all_series[this->subtasks_nr] = series; this->subtasks[this->subtasks_nr] = series->first; this->subtasks_nr++; } SubTask *ParallelWork::done() { SeriesWork *series = series_of(this); size_t i; if (this->callback) this->callback(this); for (i = 0; i < this->subtasks_nr; i++) delete this->all_series[i]; this->subtasks_nr = 0; delete this; return series->pop(); } ParallelWork::~ParallelWork() { size_t i; for (i = 0; i < this->subtasks_nr; i++) { this->all_series[i]->in_parallel = NULL; 
this->all_series[i]->dismiss_recursive(); } delete []this->subtasks; } workflow-0.11.8/src/factory/Workflow.h000066400000000000000000000156171476003635400177160ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Xie Han (xiehan@sogou-inc.com) */ #ifndef _WORKFLOW_H_ #define _WORKFLOW_H_ #include #include #include #include #include #include "SubTask.h" class SeriesWork; class ParallelWork; using series_callback_t = std::function; using parallel_callback_t = std::function; class Workflow { public: static SeriesWork * create_series_work(SubTask *first, series_callback_t callback); static void start_series_work(SubTask *first, series_callback_t callback); static ParallelWork * create_parallel_work(parallel_callback_t callback); static ParallelWork * create_parallel_work(SeriesWork *const all_series[], size_t n, parallel_callback_t callback); static void start_parallel_work(SeriesWork *const all_series[], size_t n, parallel_callback_t callback); public: static SeriesWork * create_series_work(SubTask *first, SubTask *last, series_callback_t callback); static void start_series_work(SubTask *first, SubTask *last, series_callback_t callback); }; class SeriesWork { public: void start() { assert(!this->in_parallel); this->first->dispatch(); } /* Call dismiss() only when you don't want to start a created series. * This operation is recursive, so only call on the "root". 
*/ void dismiss() { assert(!this->in_parallel); this->dismiss_recursive(); } public: void push_back(SubTask *task); void push_front(SubTask *task); public: void *get_context() const { return this->context; } void set_context(void *context) { this->context = context; } public: /* Cancel a running series. Typically, called in the callback of a task * that belongs to the series. All subsequent tasks in the series will be * destroyed immediately and recursively (ParallelWork), without callback. * But the callback of this canceled series will still be called. */ virtual void cancel() { this->canceled = true; } /* Parallel work's callback may check the cancellation state of each * sub-series, and cancel it's super-series recursively. */ bool is_canceled() const { return this->canceled; } /* 'false' until the time of callback. Mainly for sub-class. */ bool is_finished() const { return this->finished; } public: void set_callback(series_callback_t callback) { this->callback = std::move(callback); } public: virtual void *get_specific(const char *key) { return NULL; } public: /* The following functions are intended for task implementations only. 
*/ SubTask *pop(); SubTask *get_last_task() const { return this->last; } void set_last_task(SubTask *last) { last->set_pointer(this); this->last = last; } void unset_last_task() { this->last = NULL; } const ParallelTask *get_in_parallel() const { return this->in_parallel; } protected: void set_in_parallel(const ParallelTask *task) { this->in_parallel = task; } void dismiss_recursive(); protected: void *context; series_callback_t callback; private: SubTask *pop_task(); void expand_queue(); private: SubTask *buf[4]; SubTask *first; SubTask *last; SubTask **queue; int queue_size; int front; int back; bool canceled; bool finished; const ParallelTask *in_parallel; std::mutex mutex; protected: SeriesWork(SubTask *first, series_callback_t&& callback); virtual ~SeriesWork(); friend class ParallelWork; friend class Workflow; }; static inline SeriesWork *series_of(const SubTask *task) { return (SeriesWork *)task->get_pointer(); } static inline SeriesWork& operator *(const SubTask& task) { return *series_of(&task); } static inline SeriesWork& operator << (SeriesWork& series, SubTask *task) { series.push_back(task); return series; } inline SeriesWork * Workflow::create_series_work(SubTask *first, series_callback_t callback) { return new SeriesWork(first, std::move(callback)); } inline void Workflow::start_series_work(SubTask *first, series_callback_t callback) { new SeriesWork(first, std::move(callback)); first->dispatch(); } inline SeriesWork * Workflow::create_series_work(SubTask *first, SubTask *last, series_callback_t callback) { SeriesWork *series = new SeriesWork(first, std::move(callback)); series->set_last_task(last); return series; } inline void Workflow::start_series_work(SubTask *first, SubTask *last, series_callback_t callback) { SeriesWork *series = new SeriesWork(first, std::move(callback)); series->set_last_task(last); first->dispatch(); } class ParallelWork : public ParallelTask { public: void start() { assert(!series_of(this)); 
Workflow::start_series_work(this, nullptr); } void dismiss() { assert(!series_of(this)); delete this; } public: void add_series(SeriesWork *series); public: void *get_context() const { return this->context; } void set_context(void *context) { this->context = context; } public: SeriesWork *series_at(size_t index) { if (index < this->subtasks_nr) return this->all_series[index]; else return NULL; } const SeriesWork *series_at(size_t index) const { if (index < this->subtasks_nr) return this->all_series[index]; else return NULL; } SeriesWork& operator[] (size_t index) { return *this->series_at(index); } const SeriesWork& operator[] (size_t index) const { return *this->series_at(index); } size_t size() const { return this->subtasks_nr; } public: void set_callback(parallel_callback_t callback) { this->callback = std::move(callback); } protected: virtual SubTask *done(); protected: void *context; parallel_callback_t callback; private: void expand_buf(); private: size_t buf_size; SeriesWork **all_series; protected: ParallelWork(parallel_callback_t&& callback); ParallelWork(SeriesWork *const all_series[], size_t n, parallel_callback_t&& callback); virtual ~ParallelWork(); friend class Workflow; }; inline ParallelWork * Workflow::create_parallel_work(parallel_callback_t callback) { return new ParallelWork(std::move(callback)); } inline ParallelWork * Workflow::create_parallel_work(SeriesWork *const all_series[], size_t n, parallel_callback_t callback) { return new ParallelWork(all_series, n, std::move(callback)); } inline void Workflow::start_parallel_work(SeriesWork *const all_series[], size_t n, parallel_callback_t callback) { ParallelWork *p = new ParallelWork(all_series, n, std::move(callback)); Workflow::start_series_work(p, nullptr); } #endif workflow-0.11.8/src/factory/xmake.lua000066400000000000000000000010471476003635400175330ustar00rootroot00000000000000target("factory") add_files("*.cc") set_kind("object") if not has_config("mysql") then 
remove_files("MySQLTaskImpl.cc") end if not has_config("redis") then remove_files("RedisTaskImpl.cc") end remove_files("KafkaTaskImpl.cc") target("kafka_factory") if has_config("kafka") then add_files("KafkaTaskImpl.cc") set_kind("object") add_cxxflags("-fno-rtti") add_deps("factory") add_packages("zlib", "snappy", "zstd", "lz4") else set_kind("phony") end workflow-0.11.8/src/include/000077500000000000000000000000001476003635400156755ustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/000077500000000000000000000000001476003635400175475ustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/CommRequest.h000077700000000000000000000000001476003635400265022../../kernel/CommRequest.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/CommScheduler.h000077700000000000000000000000001476003635400272562../../kernel/CommScheduler.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/Communicator.h000077700000000000000000000000001476003635400270722../../kernel/Communicator.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/ConsulDataTypes.h000077700000000000000000000000001476003635400305172../../protocol/ConsulDataTypes.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/DnsCache.h000077700000000000000000000000001476003635400252442../../manager/DnsCache.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/DnsMessage.h000077700000000000000000000000001476003635400264352../../protocol/DnsMessage.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/DnsUtil.h000077700000000000000000000000001476003635400253172../../protocol/DnsUtil.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/EncodeStream.h000077700000000000000000000000001476003635400264112../../util/EncodeStream.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/EndpointParams.h000077700000000000000000000000001476003635400277742../../manager/EndpointParams.hustar00rootroot00000000000000
workflow-0.11.8/src/include/workflow/ExecRequest.h000077700000000000000000000000001476003635400264642../../kernel/ExecRequest.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/Executor.h000077700000000000000000000000001476003635400253662../../kernel/Executor.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/HttpMessage.h000077700000000000000000000000001476003635400270232../../protocol/HttpMessage.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/HttpUtil.h000077700000000000000000000000001476003635400257052../../protocol/HttpUtil.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/IORequest.h000077700000000000000000000000001476003635400255322../../kernel/IORequest.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/IOService_linux.h000077700000000000000000000000001476003635400301102../../kernel/IOService_linux.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/IOService_thread.h000077700000000000000000000000001476003635400303302../../kernel/IOService_thread.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/KafkaDataTypes.h000077700000000000000000000000001476003635400300432../../protocol/KafkaDataTypes.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/KafkaMessage.h000077700000000000000000000000001476003635400272172../../protocol/KafkaMessage.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/KafkaResult.h000077700000000000000000000000001476003635400270032../../protocol/KafkaResult.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/KafkaTaskImpl.inl000077700000000000000000000000001476003635400302132../../factory/KafkaTaskImpl.inlustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/LRUCache.h000077700000000000000000000000001476003635400244432../../util/LRUCache.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/MySQLMessage.h000077700000000000000000000000001476003635400270772../..
/protocol/MySQLMessage.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/MySQLMessage.inl000077700000000000000000000000001476003635400277652../../protocol/MySQLMessage.inlustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/MySQLResult.h000077700000000000000000000000001476003635400266632../../protocol/MySQLResult.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/MySQLResult.inl000077700000000000000000000000001476003635400275512../../protocol/MySQLResult.inlustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/MySQLUtil.h000077700000000000000000000000001476003635400257612../../protocol/MySQLUtil.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/PackageWrapper.h000077700000000000000000000000001476003635400301432../../protocol/PackageWrapper.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/ProtocolMessage.h000077700000000000000000000000001476003635400305672../../protocol/ProtocolMessage.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/RedisMessage.h000077700000000000000000000000001476003635400273012../../protocol/RedisMessage.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/RedisTaskImpl.inl000077700000000000000000000000001476003635400302752../../factory/RedisTaskImpl.inlustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/RouteManager.h000077700000000000000000000000001476003635400271062../../manager/RouteManager.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/SSLWrapper.h000077700000000000000000000000001476003635400263572../../protocol/SSLWrapper.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/SleepRequest.h000077700000000000000000000000001476003635400270342../../kernel/SleepRequest.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/StringUtil.h000077700000000000000000000000001476003635400256772../../util/StringUtil.hustar00rootroot00000000000000workflow-0.11.8/src/include/w
orkflow/SubTask.h000077700000000000000000000000001476003635400247222../../kernel/SubTask.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/TLVMessage.h000077700000000000000000000000001476003635400262772../../protocol/TLVMessage.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/URIParser.h000077700000000000000000000000001476003635400251172../../util/URIParser.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/UpstreamManager.h000077700000000000000000000000001476003635400303122../../manager/UpstreamManager.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/UpstreamPolicies.h000077700000000000000000000000001476003635400315732../../nameservice/UpstreamPolicies.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/WFAlgoTaskFactory.h000077700000000000000000000000001476003635400307452../../factory/WFAlgoTaskFactory.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/WFAlgoTaskFactory.inl000077700000000000000000000000001476003635400316332../../factory/WFAlgoTaskFactory.inlustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/WFConnection.h000077700000000000000000000000001476003635400270512../../factory/WFConnection.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/WFConsulClient.h000077700000000000000000000000001476003635400274662../../client/WFConsulClient.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/WFDnsClient.h000077700000000000000000000000001476003635400262302../../client/WFDnsClient.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/WFDnsResolver.h000077700000000000000000000000001476003635400302212../../nameservice/WFDnsResolver.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/WFDnsServer.h000077700000000000000000000000001476003635400263402../../server/WFDnsServer.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/WFFacilities.h000077700000000000000000000000001476003635400267
462../../manager/WFFacilities.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/WFFacilities.inl000077700000000000000000000000001476003635400276342../../manager/WFFacilities.inlustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/WFFuture.h000077700000000000000000000000001476003635400253422../../manager/WFFuture.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/WFGlobal.h000077700000000000000000000000001476003635400252162../../manager/WFGlobal.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/WFGraphTask.h000077700000000000000000000000001476003635400264232../../factory/WFGraphTask.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/WFHttpServer.h000077700000000000000000000000001476003635400267262../../server/WFHttpServer.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/WFKafkaClient.h000077700000000000000000000000001476003635400270122../../client/WFKafkaClient.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/WFMessageQueue.h000077700000000000000000000000001476003635400276352../../factory/WFMessageQueue.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/WFMySQLConnection.h000077700000000000000000000000001476003635400304342../../client/WFMySQLConnection.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/WFMySQLServer.h000077700000000000000000000000001476003635400270022../../server/WFMySQLServer.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/WFNameService.h000077700000000000000000000000001476003635400301072../../nameservice/WFNameService.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/WFOperator.h000077700000000000000000000000001476003635400262412../../factory/WFOperator.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/WFRedisServer.h000077700000000000000000000000001476003635400272042../../server/WFRedisServer.hustar00rootroot00000000000000workflow-0.11.8/src/include/w
orkflow/WFRedisSubscriber.h000077700000000000000000000000001476003635400306462../../client/WFRedisSubscriber.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/WFResourcePool.h000077700000000000000000000000001476003635400277152../../factory/WFResourcePool.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/WFServer.h000077700000000000000000000000001476003635400252262../../server/WFServer.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/WFServiceGovernance.h000077700000000000000000000000001476003635400325252../../nameservice/WFServiceGovernance.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/WFTask.h000077700000000000000000000000001476003635400244572../../factory/WFTask.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/WFTask.inl000077700000000000000000000000001476003635400253452../../factory/WFTask.inlustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/WFTaskError.h000077700000000000000000000000001476003635400265032../../factory/WFTaskError.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/WFTaskFactory.h000077700000000000000000000000001476003635400273372../../factory/WFTaskFactory.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/WFTaskFactory.inl000077700000000000000000000000001476003635400302252../../factory/WFTaskFactory.inlustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/Workflow.h000077700000000000000000000000001476003635400256052../../factory/Workflow.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/crc32c.h000077700000000000000000000000001476003635400236252../../util/crc32c.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/dns_parser.h000077700000000000000000000000001476003635400266532../../protocol/dns_parser.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/http_parser.h000077700000000000000000000000001476003635400272412../../protocol/http_parser.hustar00roo
troot00000000000000workflow-0.11.8/src/include/workflow/json_parser.h000077700000000000000000000000001476003635400263412../../util/json_parser.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/kafka_parser.h000077700000000000000000000000001476003635400274352../../protocol/kafka_parser.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/list.h000077700000000000000000000000001476003635400237202../../kernel/list.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/mpoller.h000077700000000000000000000000001476003635400251162../../kernel/mpoller.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/msgqueue.h000077700000000000000000000000001476003635400254602../../kernel/msgqueue.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/mysql_byteorder.h000077700000000000000000000000001476003635400310232../../protocol/mysql_byteorder.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/mysql_parser.h000077700000000000000000000000001476003635400276152../../protocol/mysql_parser.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/mysql_stream.h000077700000000000000000000000001476003635400276132../../protocol/mysql_stream.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/mysql_types.h000077700000000000000000000000001476003635400273352../../protocol/mysql_types.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/poller.h000077700000000000000000000000001476003635400245642../../kernel/poller.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/rbtree.h000077700000000000000000000000001476003635400245402../../kernel/rbtree.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/redis_parser.h000077700000000000000000000000001476003635400275172../../protocol/redis_parser.hustar00rootroot00000000000000workflow-0.11.8/src/include/workflow/thrdpool.h000077700000000000000000000000001476003635400254602../../kernel/thrdpool.hustar00roo
troot00000000000000workflow-0.11.8/src/kernel/000077500000000000000000000000001476003635400155325ustar00rootroot00000000000000workflow-0.11.8/src/kernel/CMakeLists.txt000066400000000000000000000007621476003635400202770ustar00rootroot00000000000000cmake_minimum_required(VERSION 3.6) project(kernel) if (CMAKE_SYSTEM_NAME STREQUAL "Linux" OR CMAKE_SYSTEM_NAME STREQUAL "Android") set(IOSERVICE_FILE IOService_linux.cc) elseif (UNIX) set(IOSERVICE_FILE IOService_thread.cc) else () message(FATAL_ERROR "IOService unsupported.") endif () set(SRC ${IOSERVICE_FILE} mpoller.c poller.c rbtree.c msgqueue.c thrdpool.c CommRequest.cc CommScheduler.cc Communicator.cc Executor.cc SubTask.cc ) add_library(${PROJECT_NAME} OBJECT ${SRC}) workflow-0.11.8/src/kernel/CommRequest.cc000066400000000000000000000020771476003635400203130ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Author: Xie Han (xiehan@sogou-inc.com) */ #include #include "CommScheduler.h" #include "CommRequest.h" void CommRequest::handle(int state, int error) { this->state = state; this->error = error; if (error != ETIMEDOUT) this->timeout_reason = TOR_NOT_TIMEOUT; else if (!this->target) this->timeout_reason = TOR_WAIT_TIMEOUT; else if (!this->get_message_out()) this->timeout_reason = TOR_CONNECT_TIMEOUT; else this->timeout_reason = TOR_TRANSMIT_TIMEOUT; this->subtask_done(); } workflow-0.11.8/src/kernel/CommRequest.h000066400000000000000000000034511476003635400201520ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Author: Xie Han (xiehan@sogou-inc.com) */ #ifndef _COMMREQUEST_H_ #define _COMMREQUEST_H_ #include #include #include "SubTask.h" #include "Communicator.h" #include "CommScheduler.h" class CommRequest : public SubTask, public CommSession { public: CommRequest(CommSchedObject *object, CommScheduler *scheduler) { this->scheduler = scheduler; this->object = object; this->wait_timeout = 0; } CommSchedObject *get_request_object() const { return this->object; } void set_request_object(CommSchedObject *object) { this->object = object; } int get_wait_timeout() const { return this->wait_timeout; } void set_wait_timeout(int timeout) { this->wait_timeout = timeout; } public: virtual void dispatch() { if (this->scheduler->request(this, this->object, this->wait_timeout, &this->target) < 0) { this->handle(CS_STATE_ERROR, errno); } } protected: int state; int error; protected: CommTarget *target; #define TOR_NOT_TIMEOUT 0 #define TOR_WAIT_TIMEOUT 1 #define TOR_CONNECT_TIMEOUT 2 #define TOR_TRANSMIT_TIMEOUT 3 int timeout_reason; protected: int wait_timeout; CommSchedObject *object; CommScheduler *scheduler; protected: virtual void handle(int state, int error); }; #endif workflow-0.11.8/src/kernel/CommScheduler.cc000066400000000000000000000211571476003635400206010ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Author: Xie Han (xiehan@sogou-inc.com) */ #include #include #include #include #include #include "CommScheduler.h" #define PTHREAD_COND_TIMEDWAIT(cond, mutex, abstime) \ ((abstime) ? pthread_cond_timedwait(cond, mutex, abstime) : \ pthread_cond_wait(cond, mutex)) static struct timespec *__get_abstime(int timeout, struct timespec *ts) { if (timeout < 0) return NULL; clock_gettime(CLOCK_REALTIME, ts); ts->tv_sec += timeout / 1000; ts->tv_nsec += timeout % 1000 * 1000000; if (ts->tv_nsec >= 1000000000) { ts->tv_nsec -= 1000000000; ts->tv_sec++; } return ts; } int CommSchedTarget::init(const struct sockaddr *addr, socklen_t addrlen, int connect_timeout, int response_timeout, size_t max_connections) { int ret; if (max_connections == 0) { errno = EINVAL; return -1; } if (this->CommTarget::init(addr, addrlen, connect_timeout, response_timeout) >= 0) { ret = pthread_mutex_init(&this->mutex, NULL); if (ret == 0) { ret = pthread_cond_init(&this->cond, NULL); if (ret == 0) { this->max_load = max_connections; this->cur_load = 0; this->wait_cnt = 0; this->group = NULL; return 0; } pthread_mutex_destroy(&this->mutex); } errno = ret; this->CommTarget::deinit(); } return -1; } void CommSchedTarget::deinit() { pthread_cond_destroy(&this->cond); pthread_mutex_destroy(&this->mutex); this->CommTarget::deinit(); } CommTarget *CommSchedTarget::acquire(int wait_timeout) { pthread_mutex_t *mutex = &this->mutex; int ret; pthread_mutex_lock(mutex); if (this->group) { mutex = &this->group->mutex; pthread_mutex_lock(mutex); pthread_mutex_unlock(&this->mutex); } if (this->cur_load >= this->max_load) { if (wait_timeout != 0) { struct timespec ts; struct timespec *abstime = __get_abstime(wait_timeout, &ts); do { this->wait_cnt++; ret = PTHREAD_COND_TIMEDWAIT(&this->cond, mutex, abstime); this->wait_cnt--; } while (this->cur_load >= this->max_load && ret == 0); } else ret = EAGAIN; } if (this->cur_load < this->max_load) { this->cur_load++; if (this->group) { this->group->cur_load++; 
this->group->heapify(this->index); } ret = 0; } pthread_mutex_unlock(mutex); if (ret) { errno = ret; return NULL; } return this; } void CommSchedTarget::release() { pthread_mutex_t *mutex = &this->mutex; pthread_mutex_lock(mutex); if (this->group) { mutex = &this->group->mutex; pthread_mutex_lock(mutex); pthread_mutex_unlock(&this->mutex); } this->cur_load--; if (this->wait_cnt > 0) pthread_cond_signal(&this->cond); if (this->group) { this->group->cur_load--; if (this->wait_cnt == 0 && this->group->wait_cnt > 0) pthread_cond_signal(&this->group->cond); this->group->heap_adjust(this->index, this->has_idle_conn()); } pthread_mutex_unlock(mutex); } int CommSchedGroup::target_cmp(CommSchedTarget *target1, CommSchedTarget *target2) { size_t load1 = target1->cur_load * target2->max_load; size_t load2 = target2->cur_load * target1->max_load; if (load1 < load2) return -1; else if (load1 > load2) return 1; else return 0; } void CommSchedGroup::heap_adjust(int index, int swap_on_equal) { CommSchedTarget *target = this->tg_heap[index]; CommSchedTarget *parent; while (index > 0) { parent = this->tg_heap[(index - 1) / 2]; if (CommSchedGroup::target_cmp(target, parent) < swap_on_equal) { this->tg_heap[index] = parent; parent->index = index; index = (index - 1) / 2; } else break; } this->tg_heap[index] = target; target->index = index; } /* Fastest heapify ever. 
*/ void CommSchedGroup::heapify(int top) { CommSchedTarget *target = this->tg_heap[top]; int last = this->heap_size - 1; CommSchedTarget **child; int i; while (i = 2 * top + 1, i < last) { child = &this->tg_heap[i]; if (CommSchedGroup::target_cmp(child[0], target) < 0) { if (CommSchedGroup::target_cmp(child[1], child[0]) < 0) { this->tg_heap[top] = child[1]; child[1]->index = top; top = i + 1; } else { this->tg_heap[top] = child[0]; child[0]->index = top; top = i; } } else { if (CommSchedGroup::target_cmp(child[1], target) < 0) { this->tg_heap[top] = child[1]; child[1]->index = top; top = i + 1; } else { this->tg_heap[top] = target; target->index = top; return; } } } if (i == last) { child = &this->tg_heap[i]; if (CommSchedGroup::target_cmp(child[0], target) < 0) { this->tg_heap[top] = child[0]; child[0]->index = top; top = i; } } this->tg_heap[top] = target; target->index = top; } int CommSchedGroup::heap_insert(CommSchedTarget *target) { if (this->heap_size == this->heap_buf_size) { int new_size = 2 * this->heap_buf_size; void *new_base = realloc(this->tg_heap, new_size * sizeof (void *)); if (new_base) { this->tg_heap = (CommSchedTarget **)new_base; this->heap_buf_size = new_size; } else return -1; } this->tg_heap[this->heap_size] = target; target->index = this->heap_size; this->heap_adjust(this->heap_size, 0); this->heap_size++; return 0; } void CommSchedGroup::heap_remove(int index) { CommSchedTarget *target; this->heap_size--; if (index != this->heap_size) { target = this->tg_heap[this->heap_size]; this->tg_heap[index] = target; target->index = index; this->heap_adjust(index, 0); this->heapify(target->index); } } #define COMMGROUP_INIT_SIZE 4 int CommSchedGroup::init() { size_t size = COMMGROUP_INIT_SIZE * sizeof (void *); int ret; this->tg_heap = (CommSchedTarget **)malloc(size); if (this->tg_heap) { ret = pthread_mutex_init(&this->mutex, NULL); if (ret == 0) { ret = pthread_cond_init(&this->cond, NULL); if (ret == 0) { this->heap_buf_size = 
COMMGROUP_INIT_SIZE; this->heap_size = 0; this->max_load = 0; this->cur_load = 0; this->wait_cnt = 0; return 0; } pthread_mutex_destroy(&this->mutex); } errno = ret; free(this->tg_heap); } return -1; } void CommSchedGroup::deinit() { pthread_cond_destroy(&this->cond); pthread_mutex_destroy(&this->mutex); free(this->tg_heap); } int CommSchedGroup::add(CommSchedTarget *target) { int ret = -1; pthread_mutex_lock(&target->mutex); pthread_mutex_lock(&this->mutex); if (target->group == NULL && target->wait_cnt == 0) { if (this->heap_insert(target) >= 0) { target->group = this; this->max_load += target->max_load; this->cur_load += target->cur_load; if (this->wait_cnt > 0 && this->cur_load < this->max_load) pthread_cond_signal(&this->cond); ret = 0; } } else if (target->group == this) errno = EEXIST; else if (target->group) errno = EINVAL; else errno = EBUSY; pthread_mutex_unlock(&this->mutex); pthread_mutex_unlock(&target->mutex); return ret; } int CommSchedGroup::remove(CommSchedTarget *target) { int ret = -1; pthread_mutex_lock(&target->mutex); pthread_mutex_lock(&this->mutex); if (target->group == this && target->wait_cnt == 0) { this->heap_remove(target->index); this->max_load -= target->max_load; this->cur_load -= target->cur_load; target->group = NULL; ret = 0; } else if (target->group != this) errno = ENOENT; else errno = EBUSY; pthread_mutex_unlock(&this->mutex); pthread_mutex_unlock(&target->mutex); return ret; } CommTarget *CommSchedGroup::acquire(int wait_timeout) { pthread_mutex_t *mutex = &this->mutex; CommSchedTarget *target; int ret; pthread_mutex_lock(mutex); if (this->cur_load >= this->max_load) { if (wait_timeout != 0) { struct timespec ts; struct timespec *abstime = __get_abstime(wait_timeout, &ts); do { this->wait_cnt++; ret = PTHREAD_COND_TIMEDWAIT(&this->cond, mutex, abstime); this->wait_cnt--; } while (this->cur_load >= this->max_load && ret == 0); } else ret = EAGAIN; } if (this->cur_load < this->max_load) { target = this->tg_heap[0]; 
target->cur_load++; this->cur_load++; this->heapify(0); ret = 0; } pthread_mutex_unlock(mutex); if (ret) { errno = ret; return NULL; } return target; } workflow-0.11.8/src/kernel/CommScheduler.h000066400000000000000000000104201476003635400204320ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Xie Han (xiehan@sogou-inc.com) */ #ifndef _COMMSCHEDULER_H_ #define _COMMSCHEDULER_H_ #include #include #include #include #include "Communicator.h" class CommSchedObject { public: size_t get_max_load() const { return this->max_load; } size_t get_cur_load() const { return this->cur_load; } private: virtual CommTarget *acquire(int wait_timeout) = 0; protected: size_t max_load; size_t cur_load; public: virtual ~CommSchedObject() { } friend class CommScheduler; }; class CommSchedGroup; class CommSchedTarget : public CommSchedObject, public CommTarget { public: int init(const struct sockaddr *addr, socklen_t addrlen, int connect_timeout, int response_timeout, size_t max_connections); void deinit(); public: int init(const struct sockaddr *addr, socklen_t addrlen, SSL_CTX *ssl_ctx, int connect_timeout, int ssl_connect_timeout, int response_timeout, size_t max_connections) { int ret = this->init(addr, addrlen, connect_timeout, response_timeout, max_connections); if (ret >= 0) this->set_ssl(ssl_ctx, ssl_connect_timeout); return ret; } private: virtual CommTarget *acquire(int wait_timeout); /* final */ virtual void release(); /* final */ private: 
CommSchedGroup *group; int index; int wait_cnt; pthread_mutex_t mutex; pthread_cond_t cond; friend class CommSchedGroup; }; class CommSchedGroup : public CommSchedObject { public: int init(); void deinit(); int add(CommSchedTarget *target); int remove(CommSchedTarget *target); private: virtual CommTarget *acquire(int wait_timeout); /* final */ private: CommSchedTarget **tg_heap; int heap_size; int heap_buf_size; int wait_cnt; pthread_mutex_t mutex; pthread_cond_t cond; private: static int target_cmp(CommSchedTarget *target1, CommSchedTarget *target2); void heapify(int top); void heap_adjust(int index, int swap_on_equal); int heap_insert(CommSchedTarget *target); void heap_remove(int index); friend class CommSchedTarget; }; class CommScheduler { public: int init(size_t poller_threads, size_t handler_threads) { return this->comm.init(poller_threads, handler_threads); } void deinit() { this->comm.deinit(); } /* wait_timeout in milliseconds, -1 for no timeout. */ int request(CommSession *session, CommSchedObject *object, int wait_timeout, CommTarget **target) { int ret = -1; *target = object->acquire(wait_timeout); if (*target) { ret = this->comm.request(session, *target); if (ret < 0) (*target)->release(); } return ret; } /* for services. */ int reply(CommSession *session) { return this->comm.reply(session); } int shutdown(CommSession *session) { return this->comm.shutdown(session); } int push(const void *buf, size_t size, CommSession *session) { return this->comm.push(buf, size, session); } int bind(CommService *service) { return this->comm.bind(service); } void unbind(CommService *service) { this->comm.unbind(service); } /* for sleepers. */ int sleep(SleepSession *session) { return this->comm.sleep(session); } /* Call 'unsleep' only before 'handle()' returns. */ int unsleep(SleepSession *session) { return this->comm.unsleep(session); } /* for file aio services. 
*/ int io_bind(IOService *service) { return this->comm.io_bind(service); } void io_unbind(IOService *service) { this->comm.io_unbind(service); } public: int is_handler_thread() const { return this->comm.is_handler_thread(); } int increase_handler_thread() { return this->comm.increase_handler_thread(); } int decrease_handler_thread() { return this->comm.decrease_handler_thread(); } private: Communicator comm; public: virtual ~CommScheduler() { } }; #endif workflow-0.11.8/src/kernel/Communicator.cc000066400000000000000000001271551476003635400205140ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Xie Han (xiehan@sogou-inc.com) */ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "list.h" #include "msgqueue.h" #include "thrdpool.h" #include "poller.h" #include "mpoller.h" #include "Communicator.h" struct CommConnEntry { struct list_head list; CommConnection *conn; long long seq; int sockfd; #define CONN_STATE_CONNECTING 0 #define CONN_STATE_CONNECTED 1 #define CONN_STATE_RECEIVING 2 #define CONN_STATE_SUCCESS 3 #define CONN_STATE_IDLE 4 #define CONN_STATE_KEEPALIVE 5 #define CONN_STATE_CLOSING 6 #define CONN_STATE_ERROR 7 int state; int error; int ref; struct iovec *write_iov; SSL *ssl; CommSession *session; CommTarget *target; CommService *service; mpoller_t *mpoller; /* Connection entry's mutex is for client session only. 
*/ pthread_mutex_t mutex; }; static inline int __set_fd_nonblock(int fd) { int flags = fcntl(fd, F_GETFL); if (flags >= 0) flags = fcntl(fd, F_SETFL, flags | O_NONBLOCK); return flags; } static int __bind_sockaddr(int sockfd, const struct sockaddr *addr, socklen_t addrlen) { struct sockaddr_storage ss; socklen_t len; len = sizeof (struct sockaddr_storage); if (getsockname(sockfd, (struct sockaddr *)&ss, &len) < 0) return -1; ss.ss_family = 0; while (len != 0) { if (((char *)&ss)[--len] != 0) break; } if (len == 0) { if (bind(sockfd, addr, addrlen) < 0) return -1; } return 0; } static int __create_ssl(SSL_CTX *ssl_ctx, struct CommConnEntry *entry) { BIO *bio = BIO_new_socket(entry->sockfd, BIO_NOCLOSE); if (bio) { entry->ssl = SSL_new(ssl_ctx); if (entry->ssl) { SSL_set_bio(entry->ssl, bio, bio); return 0; } BIO_free(bio); } return -1; } static void __release_conn(struct CommConnEntry *entry) { delete entry->conn; if (!entry->service) pthread_mutex_destroy(&entry->mutex); if (entry->ssl) SSL_free(entry->ssl); close(entry->sockfd); free(entry); } int CommTarget::init(const struct sockaddr *addr, socklen_t addrlen, int connect_timeout, int response_timeout) { int ret; this->addr = (struct sockaddr *)malloc(addrlen); if (this->addr) { ret = pthread_mutex_init(&this->mutex, NULL); if (ret == 0) { memcpy(this->addr, addr, addrlen); this->addrlen = addrlen; this->connect_timeout = connect_timeout; this->response_timeout = response_timeout; INIT_LIST_HEAD(&this->idle_list); this->ssl_ctx = NULL; this->ssl_connect_timeout = 0; return 0; } errno = ret; free(this->addr); } return -1; } void CommTarget::deinit() { pthread_mutex_destroy(&this->mutex); free(this->addr); } int CommMessageIn::feedback(const void *buf, size_t size) { struct CommConnEntry *entry = this->entry; const struct sockaddr *addr; socklen_t addrlen; int ret; if (!entry->ssl) { if (entry->service) { entry->target->get_addr(&addr, &addrlen); return sendto(entry->sockfd, buf, size, 0, addr, addrlen); } else 
return write(entry->sockfd, buf, size); } if (size == 0) return 0; ret = SSL_write(entry->ssl, buf, size); if (ret <= 0) { ret = SSL_get_error(entry->ssl, ret); if (ret != SSL_ERROR_SYSCALL) errno = -ret; ret = -1; } return ret; } void CommMessageIn::renew() { CommSession *session = this->entry->session; session->timeout = -1; session->begin_time.tv_sec = -1; session->begin_time.tv_nsec = -1; } int CommService::init(const struct sockaddr *bind_addr, socklen_t addrlen, int listen_timeout, int response_timeout) { int ret; this->bind_addr = (struct sockaddr *)malloc(addrlen); if (this->bind_addr) { ret = pthread_mutex_init(&this->mutex, NULL); if (ret == 0) { memcpy(this->bind_addr, bind_addr, addrlen); this->addrlen = addrlen; this->listen_timeout = listen_timeout; this->response_timeout = response_timeout; INIT_LIST_HEAD(&this->keep_alive_list); this->ssl_ctx = NULL; this->ssl_accept_timeout = 0; return 0; } errno = ret; free(this->bind_addr); } return -1; } void CommService::deinit() { pthread_mutex_destroy(&this->mutex); free(this->bind_addr); } int CommService::drain(int max) { struct CommConnEntry *entry; struct list_head *pos; int errno_bak; int cnt = 0; errno_bak = errno; pthread_mutex_lock(&this->mutex); while (cnt != max && !list_empty(&this->keep_alive_list)) { pos = this->keep_alive_list.prev; entry = list_entry(pos, struct CommConnEntry, list); list_del(pos); cnt++; /* Cannot change the sequence of next two lines. 
*/ mpoller_del(entry->sockfd, entry->mpoller); entry->state = CONN_STATE_CLOSING; } pthread_mutex_unlock(&this->mutex); errno = errno_bak; return cnt; } inline void CommService::incref() { __sync_add_and_fetch(&this->ref, 1); } inline void CommService::decref() { if (__sync_sub_and_fetch(&this->ref, 1) == 0) this->handle_unbound(); } class CommServiceTarget : public CommTarget { public: void incref() { __sync_add_and_fetch(&this->ref, 1); } void decref() { if (__sync_sub_and_fetch(&this->ref, 1) == 0) { this->service->decref(); this->deinit(); delete this; } } public: int shutdown(); private: int sockfd; int ref; private: CommService *service; private: virtual ~CommServiceTarget() { } friend class Communicator; }; int CommServiceTarget::shutdown() { struct CommConnEntry *entry; int errno_bak; int ret = 0; pthread_mutex_lock(&this->mutex); if (!list_empty(&this->idle_list)) { entry = list_entry(this->idle_list.next, struct CommConnEntry, list); list_del(&entry->list); if (this->service->reliable) { errno_bak = errno; mpoller_del(entry->sockfd, entry->mpoller); entry->state = CONN_STATE_CLOSING; errno = errno_bak; } else { __release_conn(entry); this->decref(); } ret = 1; } pthread_mutex_unlock(&this->mutex); return ret; } CommSession::~CommSession() { CommServiceTarget *target; if (!this->passive) return; target = (CommServiceTarget *)this->target; if (!this->out && target->has_idle_conn()) target->shutdown(); target->decref(); } inline int Communicator::first_timeout(CommSession *session) { int timeout = session->target->response_timeout; if (timeout < 0 || (unsigned int)session->timeout <= (unsigned int)timeout) { timeout = session->timeout; session->timeout = 0; } else clock_gettime(CLOCK_MONOTONIC, &session->begin_time); return timeout; } int Communicator::next_timeout(CommSession *session) { int timeout = session->target->response_timeout; struct timespec cur_time; int time_used, time_left; if (session->timeout > 0) { clock_gettime(CLOCK_MONOTONIC, &cur_time); 
time_used = 1000 * (cur_time.tv_sec - session->begin_time.tv_sec) + (cur_time.tv_nsec - session->begin_time.tv_nsec) / 1000000; time_left = session->timeout - time_used; if (time_left <= timeout) /* here timeout >= 0 */ { timeout = time_left < 0 ? 0 : time_left; session->timeout = 0; } } return timeout; } int Communicator::first_timeout_send(CommSession *session) { session->timeout = session->send_timeout(); return Communicator::first_timeout(session); } int Communicator::first_timeout_recv(CommSession *session) { session->timeout = session->receive_timeout(); return Communicator::first_timeout(session); } void Communicator::shutdown_service(CommService *service) { close(service->listen_fd); service->listen_fd = -1; service->drain(-1); service->decref(); } #ifndef IOV_MAX # ifdef UIO_MAXIOV # define IOV_MAX UIO_MAXIOV # else # define IOV_MAX 1024 # endif #endif int Communicator::send_message_sync(struct iovec vectors[], int cnt, struct CommConnEntry *entry) { CommSession *session = entry->session; CommService *service; int timeout; ssize_t n; int i; while (cnt > 0) { if (!entry->ssl) { n = writev(entry->sockfd, vectors, cnt <= IOV_MAX ? cnt : IOV_MAX); if (n < 0) return errno == EAGAIN ? 
cnt : -1; } else if (vectors->iov_len > 0) { n = SSL_write(entry->ssl, vectors->iov_base, vectors->iov_len); if (n <= 0) return cnt; } else n = 0; for (i = 0; i < cnt; i++) { if ((size_t)n >= vectors[i].iov_len) n -= vectors[i].iov_len; else { vectors[i].iov_base = (char *)vectors[i].iov_base + n; vectors[i].iov_len -= n; break; } } vectors += i; cnt -= i; } service = entry->service; if (service) { __sync_add_and_fetch(&entry->ref, 1); timeout = session->keep_alive_timeout(); switch (timeout) { default: mpoller_set_timeout(entry->sockfd, timeout, this->mpoller); pthread_mutex_lock(&service->mutex); if (service->listen_fd >= 0) { entry->state = CONN_STATE_KEEPALIVE; list_add(&entry->list, &service->keep_alive_list); entry = NULL; } pthread_mutex_unlock(&service->mutex); if (entry) { case 0: mpoller_del(entry->sockfd, this->mpoller); entry->state = CONN_STATE_CLOSING; } } } else { if (entry->state == CONN_STATE_IDLE) { timeout = session->first_timeout(); if (timeout == 0) timeout = Communicator::first_timeout_recv(session); else { session->timeout = -1; session->begin_time.tv_sec = -1; session->begin_time.tv_nsec = 0; } mpoller_set_timeout(entry->sockfd, timeout, this->mpoller); } entry->state = CONN_STATE_RECEIVING; } return 0; } int Communicator::send_message_async(struct iovec vectors[], int cnt, struct CommConnEntry *entry) { struct poller_data data; int timeout; int ret; int i; entry->write_iov = (struct iovec *)malloc(cnt * sizeof (struct iovec)); if (entry->write_iov) { for (i = 0; i < cnt; i++) entry->write_iov[i] = vectors[i]; } else return -1; data.operation = PD_OP_WRITE; data.fd = entry->sockfd; data.ssl = entry->ssl; data.partial_written = Communicator::partial_written; data.context = entry; data.write_iov = entry->write_iov; data.iovcnt = cnt; timeout = Communicator::first_timeout_send(entry->session); if (entry->state == CONN_STATE_IDLE) { ret = mpoller_mod(&data, timeout, this->mpoller); if (ret < 0 && errno == ENOENT) entry->state = 
CONN_STATE_RECEIVING; } else { ret = mpoller_add(&data, timeout, this->mpoller); if (ret >= 0) { if (this->stop_flag) mpoller_del(data.fd, this->mpoller); } } if (ret < 0) { free(entry->write_iov); if (entry->state != CONN_STATE_RECEIVING) return -1; } return 1; } #define ENCODE_IOV_MAX 2048 int Communicator::send_message(struct CommConnEntry *entry) { struct iovec vectors[ENCODE_IOV_MAX]; struct iovec *end; int cnt; cnt = entry->session->out->encode(vectors, ENCODE_IOV_MAX); if ((unsigned int)cnt > ENCODE_IOV_MAX) { if (cnt > ENCODE_IOV_MAX) errno = EOVERFLOW; return -1; } end = vectors + cnt; cnt = this->send_message_sync(vectors, cnt, entry); if (cnt <= 0) return cnt; return this->send_message_async(end - cnt, cnt, entry); } void Communicator::handle_incoming_request(struct poller_result *res) { struct CommConnEntry *entry = (struct CommConnEntry *)res->data.context; CommTarget *target = entry->target; CommSession *session = NULL; int state; switch (res->state) { case PR_ST_SUCCESS: session = entry->session; state = CS_STATE_TOREPLY; pthread_mutex_lock(&target->mutex); if (entry->state == CONN_STATE_SUCCESS) { __sync_add_and_fetch(&entry->ref, 1); entry->state = CONN_STATE_IDLE; list_add(&entry->list, &target->idle_list); } pthread_mutex_unlock(&target->mutex); break; case PR_ST_FINISHED: res->error = ECONNRESET; if (1) case PR_ST_ERROR: state = CS_STATE_ERROR; else case PR_ST_DELETED: case PR_ST_STOPPED: state = CS_STATE_STOPPED; pthread_mutex_lock(&target->mutex); switch (entry->state) { case CONN_STATE_KEEPALIVE: pthread_mutex_lock(&entry->service->mutex); if (entry->state == CONN_STATE_KEEPALIVE) list_del(&entry->list); pthread_mutex_unlock(&entry->service->mutex); break; case CONN_STATE_IDLE: list_del(&entry->list); break; case CONN_STATE_ERROR: res->error = entry->error; state = CS_STATE_ERROR; case CONN_STATE_RECEIVING: session = entry->session; break; case CONN_STATE_SUCCESS: /* This may happen only if handler_threads > 1. 
*/ entry->state = CONN_STATE_CLOSING; entry = NULL; break; } pthread_mutex_unlock(&target->mutex); break; } if (entry) { if (session) session->handle(state, res->error); if (__sync_sub_and_fetch(&entry->ref, 1) == 0) { __release_conn(entry); ((CommServiceTarget *)target)->decref(); } } } void Communicator::handle_incoming_reply(struct poller_result *res) { struct CommConnEntry *entry = (struct CommConnEntry *)res->data.context; CommTarget *target = entry->target; CommSession *session = NULL; pthread_mutex_t *mutex; int state; switch (res->state) { case PR_ST_SUCCESS: session = entry->session; state = CS_STATE_SUCCESS; pthread_mutex_lock(&target->mutex); if (entry->state == CONN_STATE_SUCCESS) { __sync_add_and_fetch(&entry->ref, 1); if (session->timeout != 0) /* This is keep-alive timeout. */ { entry->state = CONN_STATE_IDLE; list_add(&entry->list, &target->idle_list); } else entry->state = CONN_STATE_CLOSING; } pthread_mutex_unlock(&target->mutex); break; case PR_ST_FINISHED: res->error = ECONNRESET; if (1) case PR_ST_ERROR: state = CS_STATE_ERROR; else case PR_ST_DELETED: case PR_ST_STOPPED: state = CS_STATE_STOPPED; mutex = &entry->mutex; pthread_mutex_lock(&target->mutex); pthread_mutex_lock(mutex); switch (entry->state) { case CONN_STATE_IDLE: list_del(&entry->list); break; case CONN_STATE_ERROR: res->error = entry->error; state = CS_STATE_ERROR; case CONN_STATE_RECEIVING: session = entry->session; break; case CONN_STATE_SUCCESS: /* This may happen only if handler_threads > 1. 
*/ entry->state = CONN_STATE_CLOSING; entry = NULL; break; } pthread_mutex_unlock(&target->mutex); pthread_mutex_unlock(mutex); break; } if (entry) { if (session) { target->release(); session->handle(state, res->error); } if (__sync_sub_and_fetch(&entry->ref, 1) == 0) __release_conn(entry); } } void Communicator::handle_read_result(struct poller_result *res) { struct CommConnEntry *entry = (struct CommConnEntry *)res->data.context; if (res->state != PR_ST_MODIFIED) { if (entry->service) this->handle_incoming_request(res); else this->handle_incoming_reply(res); } } void Communicator::handle_reply_result(struct poller_result *res) { struct CommConnEntry *entry = (struct CommConnEntry *)res->data.context; CommService *service = entry->service; CommSession *session = entry->session; CommTarget *target = entry->target; int timeout; int state; switch (res->state) { case PR_ST_FINISHED: timeout = session->keep_alive_timeout(); if (timeout != 0) { __sync_add_and_fetch(&entry->ref, 1); res->data.operation = PD_OP_READ; res->data.create_message = Communicator::create_request; res->data.message = NULL; pthread_mutex_lock(&target->mutex); if (mpoller_add(&res->data, timeout, this->mpoller) >= 0) { pthread_mutex_lock(&service->mutex); if (!this->stop_flag && service->listen_fd >= 0) { entry->state = CONN_STATE_KEEPALIVE; list_add(&entry->list, &service->keep_alive_list); } else { mpoller_del(res->data.fd, this->mpoller); entry->state = CONN_STATE_CLOSING; } pthread_mutex_unlock(&service->mutex); } else __sync_sub_and_fetch(&entry->ref, 1); pthread_mutex_unlock(&target->mutex); } if (1) state = CS_STATE_SUCCESS; else if (1) case PR_ST_ERROR: state = CS_STATE_ERROR; else case PR_ST_DELETED: /* DELETED seems not possible. 
*/ case PR_ST_STOPPED: state = CS_STATE_STOPPED; session->handle(state, res->error); if (__sync_sub_and_fetch(&entry->ref, 1) == 0) { __release_conn(entry); ((CommServiceTarget *)target)->decref(); } break; } } void Communicator::handle_request_result(struct poller_result *res) { struct CommConnEntry *entry = (struct CommConnEntry *)res->data.context; CommSession *session = entry->session; int timeout; int state; switch (res->state) { case PR_ST_FINISHED: entry->state = CONN_STATE_RECEIVING; res->data.operation = PD_OP_READ; res->data.create_message = Communicator::create_reply; res->data.message = NULL; timeout = session->first_timeout(); if (timeout == 0) timeout = Communicator::first_timeout_recv(session); else { session->timeout = -1; session->begin_time.tv_sec = -1; session->begin_time.tv_nsec = 0; } if (mpoller_add(&res->data, timeout, this->mpoller) >= 0) { if (this->stop_flag) mpoller_del(res->data.fd, this->mpoller); break; } res->error = errno; if (1) case PR_ST_ERROR: state = CS_STATE_ERROR; else case PR_ST_DELETED: case PR_ST_STOPPED: state = CS_STATE_STOPPED; entry->target->release(); session->handle(state, res->error); pthread_mutex_lock(&entry->mutex); /* do nothing */ pthread_mutex_unlock(&entry->mutex); if (__sync_sub_and_fetch(&entry->ref, 1) == 0) __release_conn(entry); break; } } void Communicator::handle_write_result(struct poller_result *res) { struct CommConnEntry *entry = (struct CommConnEntry *)res->data.context; free(entry->write_iov); if (entry->service) this->handle_reply_result(res); else this->handle_request_result(res); } struct CommConnEntry *Communicator::accept_conn(CommServiceTarget *target, CommService *service) { struct CommConnEntry *entry; size_t size; if (__set_fd_nonblock(target->sockfd) >= 0) { size = offsetof(struct CommConnEntry, mutex); entry = (struct CommConnEntry *)malloc(size); if (entry) { entry->conn = service->new_connection(target->sockfd); if (entry->conn) { entry->seq = 0; entry->mpoller = NULL; entry->service 
= service; entry->target = target; entry->ssl = NULL; entry->sockfd = target->sockfd; entry->state = CONN_STATE_CONNECTED; entry->ref = 1; return entry; } free(entry); } } return NULL; } void Communicator::handle_connect_result(struct poller_result *res) { struct CommConnEntry *entry = (struct CommConnEntry *)res->data.context; CommSession *session = entry->session; CommTarget *target = entry->target; int timeout; int state; int ret; switch (res->state) { case PR_ST_FINISHED: if (target->ssl_ctx && !entry->ssl) { if (__create_ssl(target->ssl_ctx, entry) >= 0 && target->init_ssl(entry->ssl) >= 0) { ret = 0; res->data.operation = PD_OP_SSL_CONNECT; res->data.ssl = entry->ssl; timeout = target->ssl_connect_timeout; } else ret = -1; } else if ((session->out = session->message_out()) != NULL) { ret = this->send_message(entry); if (ret == 0) { res->data.operation = PD_OP_READ; res->data.create_message = Communicator::create_reply; res->data.message = NULL; timeout = session->first_timeout(); if (timeout == 0) timeout = Communicator::first_timeout_recv(session); else { session->timeout = -1; session->begin_time.tv_sec = -1; session->begin_time.tv_nsec = 0; } } else if (ret > 0) break; } else ret = -1; if (ret >= 0) { if (mpoller_add(&res->data, timeout, this->mpoller) >= 0) { if (this->stop_flag) mpoller_del(res->data.fd, this->mpoller); break; } } res->error = errno; if (1) case PR_ST_ERROR: state = CS_STATE_ERROR; else case PR_ST_DELETED: case PR_ST_STOPPED: state = CS_STATE_STOPPED; target->release(); session->handle(state, res->error); __release_conn(entry); break; } } void Communicator::handle_listen_result(struct poller_result *res) { CommService *service = (CommService *)res->data.context; struct CommConnEntry *entry; CommServiceTarget *target; int timeout; switch (res->state) { case PR_ST_SUCCESS: target = (CommServiceTarget *)res->data.result; entry = Communicator::accept_conn(target, service); if (entry) { entry->mpoller = this->mpoller; if (service->ssl_ctx) { 
if (__create_ssl(service->ssl_ctx, entry) >= 0 && service->init_ssl(entry->ssl) >= 0) { res->data.operation = PD_OP_SSL_ACCEPT; timeout = service->ssl_accept_timeout; } } else { res->data.operation = PD_OP_READ; res->data.create_message = Communicator::create_request; res->data.message = NULL; timeout = target->response_timeout; } if (res->data.operation != PD_OP_LISTEN) { res->data.fd = entry->sockfd; res->data.ssl = entry->ssl; res->data.context = entry; if (mpoller_add(&res->data, timeout, this->mpoller) >= 0) { if (this->stop_flag) mpoller_del(res->data.fd, this->mpoller); break; } } __release_conn(entry); } else close(target->sockfd); target->decref(); break; case PR_ST_DELETED: this->shutdown_service(service); break; case PR_ST_ERROR: case PR_ST_STOPPED: service->handle_stop(res->error); break; } } void Communicator::handle_recvfrom_result(struct poller_result *res) { CommService *service = (CommService *)res->data.context; struct CommConnEntry *entry; CommSession *session; CommTarget *target; int state, error; switch (res->state) { case PR_ST_SUCCESS: entry = (struct CommConnEntry *)res->data.result; session = entry->session; target = entry->target; if (entry->state == CONN_STATE_SUCCESS) { state = CS_STATE_TOREPLY; error = 0; entry->state = CONN_STATE_IDLE; list_add(&entry->list, &target->idle_list); } else { state = CS_STATE_ERROR; if (entry->state == CONN_STATE_ERROR) error = entry->error; else error = EBADMSG; } session->handle(state, error); if (state == CS_STATE_ERROR) { __release_conn(entry); ((CommServiceTarget *)target)->decref(); } break; case PR_ST_DELETED: this->shutdown_service(service); break; case PR_ST_ERROR: case PR_ST_STOPPED: service->handle_stop(res->error); break; } } void Communicator::handle_ssl_accept_result(struct poller_result *res) { struct CommConnEntry *entry = (struct CommConnEntry *)res->data.context; CommTarget *target = entry->target; int timeout; switch (res->state) { case PR_ST_FINISHED: res->data.operation = PD_OP_READ; 
res->data.create_message = Communicator::create_request; res->data.message = NULL; timeout = target->response_timeout; if (mpoller_add(&res->data, timeout, this->mpoller) >= 0) { if (this->stop_flag) mpoller_del(res->data.fd, this->mpoller); break; } case PR_ST_DELETED: case PR_ST_ERROR: case PR_ST_STOPPED: __release_conn(entry); ((CommServiceTarget *)target)->decref(); break; } } void Communicator::handle_sleep_result(struct poller_result *res) { SleepSession *session = (SleepSession *)res->data.context; int state; switch (res->state) { case PR_ST_FINISHED: state = SS_STATE_COMPLETE; break; case PR_ST_DELETED: res->error = ECANCELED; case PR_ST_ERROR: state = SS_STATE_ERROR; break; case PR_ST_STOPPED: state = SS_STATE_DISRUPTED; break; } session->handle(state, res->error); } void Communicator::handle_aio_result(struct poller_result *res) { IOService *service = (IOService *)res->data.context; IOSession *session; int state, error; switch (res->state) { case PR_ST_SUCCESS: session = (IOSession *)res->data.result; pthread_mutex_lock(&service->mutex); list_del(&session->list); pthread_mutex_unlock(&service->mutex); if (session->res >= 0) { state = IOS_STATE_SUCCESS; error = 0; } else { state = IOS_STATE_ERROR; error = -session->res; } session->handle(state, error); service->decref(); break; case PR_ST_DELETED: this->shutdown_io_service(service); break; case PR_ST_ERROR: case PR_ST_STOPPED: service->handle_stop(res->error); break; } } void Communicator::handler_thread_routine(void *context) { Communicator *comm = (Communicator *)context; struct poller_result *res; while (1) { res = (struct poller_result *)msgqueue_get(comm->msgqueue); if (!res) break; switch (res->data.operation) { case PD_OP_TIMER: comm->handle_sleep_result(res); break; case PD_OP_READ: comm->handle_read_result(res); break; case PD_OP_WRITE: comm->handle_write_result(res); break; case PD_OP_CONNECT: case PD_OP_SSL_CONNECT: comm->handle_connect_result(res); break; case PD_OP_LISTEN: 
comm->handle_listen_result(res); break; case PD_OP_RECVFROM: comm->handle_recvfrom_result(res); break; case PD_OP_SSL_ACCEPT: comm->handle_ssl_accept_result(res); break; case PD_OP_EVENT: case PD_OP_NOTIFY: comm->handle_aio_result(res); break; default: free(res); thrdpool_exit(comm->thrdpool); return; } free(res); } } int Communicator::append_message(const void *buf, size_t *size, poller_message_t *msg) { CommMessageIn *in = (CommMessageIn *)msg; struct CommConnEntry *entry = in->entry; CommSession *session = entry->session; int timeout; int ret; ret = in->append(buf, size); if (ret > 0) { entry->state = CONN_STATE_SUCCESS; if (!entry->service) { timeout = session->keep_alive_timeout(); session->timeout = timeout; /* Reuse session's timeout field. */ if (timeout == 0) { mpoller_del(entry->sockfd, entry->mpoller); return ret; } } else timeout = -1; } else if (ret == 0 && session->timeout != 0) { if (session->begin_time.tv_sec < 0) { if (session->begin_time.tv_nsec < 0) timeout = session->first_timeout(); else timeout = 0; if (timeout == 0) timeout = Communicator::first_timeout_recv(session); else session->begin_time.tv_nsec = 0; } else timeout = Communicator::next_timeout(session); } else return ret; /* This set_timeout() never fails, which is very important. 
*/ mpoller_set_timeout(entry->sockfd, timeout, entry->mpoller); return ret; } poller_message_t *Communicator::create_request(void *context) { struct CommConnEntry *entry = (struct CommConnEntry *)context; CommService *service = entry->service; CommTarget *target = entry->target; CommSession *session; CommMessageIn *in; int timeout; if (entry->state == CONN_STATE_IDLE) { pthread_mutex_lock(&target->mutex); /* do nothing */ pthread_mutex_unlock(&target->mutex); } pthread_mutex_lock(&service->mutex); if (entry->state == CONN_STATE_KEEPALIVE) list_del(&entry->list); else if (entry->state != CONN_STATE_CONNECTED) entry = NULL; pthread_mutex_unlock(&service->mutex); if (!entry) { errno = EBADMSG; return NULL; } session = service->new_session(entry->seq, entry->conn); if (!session) return NULL; session->passive = 1; entry->session = session; session->target = target; session->conn = entry->conn; session->seq = entry->seq++; session->out = NULL; session->in = NULL; timeout = Communicator::first_timeout_recv(session); mpoller_set_timeout(entry->sockfd, timeout, entry->mpoller); entry->state = CONN_STATE_RECEIVING; ((CommServiceTarget *)target)->incref(); in = session->message_in(); if (in) { in->poller_message_t::append = Communicator::append_message; in->entry = entry; session->in = in; } return in; } poller_message_t *Communicator::create_reply(void *context) { struct CommConnEntry *entry = (struct CommConnEntry *)context; CommSession *session; CommMessageIn *in; if (entry->state == CONN_STATE_IDLE) { pthread_mutex_lock(&entry->mutex); /* do nothing */ pthread_mutex_unlock(&entry->mutex); } if (entry->state != CONN_STATE_RECEIVING) { errno = EBADMSG; return NULL; } session = entry->session; in = session->message_in(); if (in) { in->poller_message_t::append = Communicator::append_message; in->entry = entry; session->in = in; } return in; } int Communicator::recv_request(const void *buf, size_t size, struct CommConnEntry *entry) { CommService *service = entry->service; 
CommTarget *target = entry->target; CommSession *session; CommMessageIn *in; size_t n; int ret; session = service->new_session(entry->seq, entry->conn); if (!session) return -1; session->passive = 1; entry->session = session; session->target = target; session->conn = entry->conn; session->seq = entry->seq++; session->out = NULL; session->in = NULL; entry->state = CONN_STATE_RECEIVING; ((CommServiceTarget *)target)->incref(); in = session->message_in(); if (in) { in->entry = entry; session->in = in; do { n = size; ret = in->append(buf, &n); if (ret == 0) { size -= n; buf = (const char *)buf + n; } else if (ret < 0) { entry->error = errno; entry->state = CONN_STATE_ERROR; } else entry->state = CONN_STATE_SUCCESS; } while (ret == 0 && size > 0); } return 0; } int Communicator::partial_written(size_t n, void *context) { struct CommConnEntry *entry = (struct CommConnEntry *)context; CommSession *session = entry->session; int timeout; timeout = Communicator::next_timeout(session); mpoller_set_timeout(entry->sockfd, timeout, entry->mpoller); return 0; } void *Communicator::accept(const struct sockaddr *addr, socklen_t addrlen, int sockfd, void *context) { CommService *service = (CommService *)context; CommServiceTarget *target = new CommServiceTarget; if (target) { if (target->init(addr, addrlen, 0, service->response_timeout) >= 0) { service->incref(); target->service = service; target->sockfd = sockfd; target->ref = 1; return target; } delete target; } close(sockfd); return NULL; } void *Communicator::recvfrom(const struct sockaddr *addr, socklen_t addrlen, const void *buf, size_t size, void *context) { CommService *service = (CommService *)context; struct CommConnEntry *entry; CommServiceTarget *target; void *result; int sockfd; sockfd = dup(service->listen_fd); if (sockfd >= 0) { result = Communicator::accept(addr, addrlen, sockfd, context); if (result) { target = (CommServiceTarget *)result; entry = Communicator::accept_conn(target, service); if (entry) { if 
(Communicator::recv_request(buf, size, entry) >= 0) return entry; __release_conn(entry); } else close(sockfd); target->decref(); } } return NULL; } void Communicator::callback(struct poller_result *res, void *context) { msgqueue_t *msgqueue = (msgqueue_t *)context; msgqueue_put(res, msgqueue); } int Communicator::create_handler_threads(size_t handler_threads) { struct thrdpool_task task = { .routine = Communicator::handler_thread_routine, .context = this }; size_t i; this->thrdpool = thrdpool_create(handler_threads, 0); if (this->thrdpool) { for (i = 0; i < handler_threads; i++) { if (thrdpool_schedule(&task, this->thrdpool) < 0) break; } if (i == handler_threads) return 0; msgqueue_set_nonblock(this->msgqueue); thrdpool_destroy(NULL, this->thrdpool); } return -1; } int Communicator::create_poller(size_t poller_threads) { struct poller_params params = { .max_open_files = (size_t)sysconf(_SC_OPEN_MAX), .callback = Communicator::callback, }; if ((ssize_t)params.max_open_files < 0) return -1; this->msgqueue = msgqueue_create(16 * 1024, sizeof (struct poller_result)); if (this->msgqueue) { params.context = this->msgqueue; this->mpoller = mpoller_create(¶ms, poller_threads); if (this->mpoller) { if (mpoller_start(this->mpoller) >= 0) return 0; mpoller_destroy(this->mpoller); } msgqueue_destroy(this->msgqueue); } return -1; } int Communicator::init(size_t poller_threads, size_t handler_threads) { if (poller_threads == 0) { errno = EINVAL; return -1; } if (this->create_poller(poller_threads) >= 0) { if (this->create_handler_threads(handler_threads) >= 0) { this->stop_flag = 0; return 0; } mpoller_stop(this->mpoller); mpoller_destroy(this->mpoller); msgqueue_destroy(this->msgqueue); } return -1; } void Communicator::deinit() { this->stop_flag = 1; mpoller_stop(this->mpoller); msgqueue_set_nonblock(this->msgqueue); thrdpool_destroy(NULL, this->thrdpool); mpoller_destroy(this->mpoller); msgqueue_destroy(this->msgqueue); } int Communicator::nonblock_connect(CommTarget 
*target) { int sockfd = target->create_connect_fd(); if (sockfd >= 0) { if (__set_fd_nonblock(sockfd) >= 0) { if (connect(sockfd, target->addr, target->addrlen) >= 0 || errno == EINPROGRESS) { return sockfd; } } close(sockfd); } return -1; } struct CommConnEntry *Communicator::launch_conn(CommSession *session, CommTarget *target) { struct CommConnEntry *entry; int sockfd; int ret; sockfd = Communicator::nonblock_connect(target); if (sockfd >= 0) { entry = (struct CommConnEntry *)malloc(sizeof (struct CommConnEntry)); if (entry) { ret = pthread_mutex_init(&entry->mutex, NULL); if (ret == 0) { entry->conn = target->new_connection(sockfd); if (entry->conn) { entry->seq = 0; entry->mpoller = NULL; entry->service = NULL; entry->target = target; entry->session = session; entry->ssl = NULL; entry->sockfd = sockfd; entry->state = CONN_STATE_CONNECTING; entry->ref = 1; return entry; } pthread_mutex_destroy(&entry->mutex); } else errno = ret; free(entry); } close(sockfd); } return NULL; } int Communicator::request_idle_conn(CommSession *session, CommTarget *target) { struct CommConnEntry *entry; struct list_head *pos; int ret = -1; while (1) { pthread_mutex_lock(&target->mutex); if (!list_empty(&target->idle_list)) { pos = target->idle_list.next; entry = list_entry(pos, struct CommConnEntry, list); list_del(pos); pthread_mutex_lock(&entry->mutex); } else entry = NULL; pthread_mutex_unlock(&target->mutex); if (!entry) { errno = ENOENT; return -1; } if (mpoller_set_timeout(entry->sockfd, -1, this->mpoller) >= 0) break; entry->state = CONN_STATE_CLOSING; pthread_mutex_unlock(&entry->mutex); } entry->session = session; session->conn = entry->conn; session->seq = entry->seq++; session->out = session->message_out(); if (session->out) ret = this->send_message(entry); if (ret < 0) { entry->error = errno; mpoller_del(entry->sockfd, this->mpoller); entry->state = CONN_STATE_ERROR; ret = 1; } pthread_mutex_unlock(&entry->mutex); return ret; } int 
Communicator::request_new_conn(CommSession *session, CommTarget *target) { struct CommConnEntry *entry; struct poller_data data; int timeout; entry = Communicator::launch_conn(session, target); if (entry) { entry->mpoller = this->mpoller; session->conn = entry->conn; session->seq = entry->seq++; data.operation = PD_OP_CONNECT; data.fd = entry->sockfd; data.ssl = NULL; data.context = entry; timeout = session->target->connect_timeout; if (mpoller_add(&data, timeout, this->mpoller) >= 0) return 0; __release_conn(entry); } return -1; } int Communicator::request(CommSession *session, CommTarget *target) { int errno_bak; if (session->passive) { errno = EINVAL; return -1; } errno_bak = errno; session->target = target; session->out = NULL; session->in = NULL; if (this->request_idle_conn(session, target) < 0) { if (this->request_new_conn(session, target) < 0) { session->conn = NULL; session->seq = 0; return -1; } } errno = errno_bak; return 0; } int Communicator::nonblock_listen(CommService *service) { int sockfd = service->create_listen_fd(); int ret; if (sockfd >= 0) { if (__set_fd_nonblock(sockfd) >= 0) { if (__bind_sockaddr(sockfd, service->bind_addr, service->addrlen) >= 0) { ret = listen(sockfd, SOMAXCONN); if (ret >= 0 || errno == EOPNOTSUPP) { service->reliable = (ret >= 0); return sockfd; } } } close(sockfd); } return -1; } int Communicator::bind(CommService *service) { struct poller_data data; int errno_bak = errno; int sockfd; sockfd = this->nonblock_listen(service); if (sockfd >= 0) { service->listen_fd = sockfd; service->ref = 1; data.fd = sockfd; data.context = service; data.result = NULL; if (service->reliable) { data.operation = PD_OP_LISTEN; data.accept = Communicator::accept; } else { data.operation = PD_OP_RECVFROM; data.recvfrom = Communicator::recvfrom; } if (mpoller_add(&data, service->listen_timeout, this->mpoller) >= 0) { errno = errno_bak; return 0; } close(sockfd); } return -1; } void Communicator::unbind(CommService *service) { int errno_bak = 
errno; if (mpoller_del(service->listen_fd, this->mpoller) < 0) { /* Error occurred on listen_fd or Communicator::deinit() called. */ this->shutdown_service(service); errno = errno_bak; } } int Communicator::reply_reliable(CommSession *session, CommTarget *target) { struct CommConnEntry *entry; struct list_head *pos; int ret = -1; pthread_mutex_lock(&target->mutex); if (!list_empty(&target->idle_list)) { pos = target->idle_list.next; entry = list_entry(pos, struct CommConnEntry, list); list_del(pos); session->out = session->message_out(); if (session->out) ret = this->send_message(entry); if (ret < 0) { entry->error = errno; mpoller_del(entry->sockfd, this->mpoller); entry->state = CONN_STATE_ERROR; ret = 1; } } else errno = ENOENT; pthread_mutex_unlock(&target->mutex); return ret; } int Communicator::reply_message_unreliable(struct CommConnEntry *entry) { struct iovec vectors[ENCODE_IOV_MAX]; int cnt; cnt = entry->session->out->encode(vectors, ENCODE_IOV_MAX); if ((unsigned int)cnt > ENCODE_IOV_MAX) { if (cnt > ENCODE_IOV_MAX) errno = EOVERFLOW; return -1; } if (cnt > 0) { struct msghdr message = { .msg_name = entry->target->addr, .msg_namelen = entry->target->addrlen, .msg_iov = vectors, #ifdef __linux__ .msg_iovlen = (size_t)cnt, #else .msg_iovlen = cnt, #endif }; if (sendmsg(entry->sockfd, &message, 0) < 0) return -1; } return 0; } int Communicator::reply_unreliable(CommSession *session, CommTarget *target) { struct CommConnEntry *entry; struct list_head *pos; if (!list_empty(&target->idle_list)) { pos = target->idle_list.next; entry = list_entry(pos, struct CommConnEntry, list); list_del(pos); session->out = session->message_out(); if (session->out) { if (this->reply_message_unreliable(entry) >= 0) return 0; } __release_conn(entry); ((CommServiceTarget *)target)->decref(); } else errno = ENOENT; return -1; } int Communicator::reply(CommSession *session) { struct CommConnEntry *entry; CommServiceTarget *target; int errno_bak; int ret; if (!session->passive) { 
errno = EINVAL; return -1; } if (session->out) { errno = ENOENT; return -1; } errno_bak = errno; target = (CommServiceTarget *)session->target; if (target->service->reliable) ret = this->reply_reliable(session, target); else ret = this->reply_unreliable(session, target); if (ret == 0) { entry = session->in->entry; session->handle(CS_STATE_SUCCESS, 0); if (__sync_sub_and_fetch(&entry->ref, 1) == 0) { __release_conn(entry); target->decref(); } } else if (ret < 0) return -1; errno = errno_bak; return 0; } int Communicator::push(const void *buf, size_t size, CommSession *session) { CommMessageIn *in = session->in; pthread_mutex_t *mutex; int ret; if (!in) { errno = ENOENT; return -1; } if (session->passive) mutex = &session->target->mutex; else mutex = &in->entry->mutex; pthread_mutex_lock(mutex); if ((!session->passive || session->target->has_idle_conn()) && in->entry->session == session) { ret = in->inner()->feedback(buf, size); } else { errno = ENOENT; ret = -1; } pthread_mutex_unlock(mutex); return ret; } int Communicator::shutdown(CommSession *session) { CommServiceTarget *target; if (!session->passive) { errno = EINVAL; return -1; } target = (CommServiceTarget *)session->target; if (session->out || !target->shutdown()) { errno = ENOENT; return -1; } return 0; } int Communicator::sleep(SleepSession *session) { struct timespec value; if (session->duration(&value) >= 0) { if (mpoller_add_timer(&value, session, &session->timer, &session->index, this->mpoller) >= 0) { return 0; } } return -1; } int Communicator::unsleep(SleepSession *session) { return mpoller_del_timer(session->timer, session->index, this->mpoller); } int Communicator::is_handler_thread() const { return thrdpool_in_pool(this->thrdpool); } extern "C" void __thrdpool_schedule(const struct thrdpool_task *, void *, thrdpool_t *); int Communicator::increase_handler_thread() { void *buf = malloc(4 * sizeof (void *)); if (buf) { if (thrdpool_increase(this->thrdpool) >= 0) { struct thrdpool_task task = { 
.routine = Communicator::handler_thread_routine, .context = this }; __thrdpool_schedule(&task, buf, this->thrdpool); return 0; } free(buf); } return -1; } int Communicator::decrease_handler_thread() { struct poller_result *res; size_t size; size = sizeof (struct poller_result) + sizeof (void *); res = (struct poller_result *)malloc(size); if (res) { res->data.operation = -1; msgqueue_put_head(res, this->msgqueue); return 0; } return -1; } #ifdef __linux__ void Communicator::shutdown_io_service(IOService *service) { pthread_mutex_lock(&service->mutex); close(service->event_fd); service->event_fd = -1; pthread_mutex_unlock(&service->mutex); service->decref(); } int Communicator::io_bind(IOService *service) { struct poller_data data; int event_fd; event_fd = service->create_event_fd(); if (event_fd >= 0) { if (__set_fd_nonblock(event_fd) >= 0) { service->ref = 1; data.operation = PD_OP_EVENT; data.fd = event_fd; data.event = IOService::aio_finish; data.context = service; data.result = NULL; if (mpoller_add(&data, -1, this->mpoller) >= 0) { service->event_fd = event_fd; return 0; } } close(event_fd); } return -1; } void Communicator::io_unbind(IOService *service) { int errno_bak = errno; if (mpoller_del(service->event_fd, this->mpoller) < 0) { /* Error occurred on event_fd or Communicator::deinit() called. 
*/ this->shutdown_io_service(service); errno = errno_bak; } } #else void Communicator::shutdown_io_service(IOService *service) { pthread_mutex_lock(&service->mutex); close(service->pipe_fd[0]); close(service->pipe_fd[1]); service->pipe_fd[0] = -1; service->pipe_fd[1] = -1; pthread_mutex_unlock(&service->mutex); service->decref(); } int Communicator::io_bind(IOService *service) { struct poller_data data; int pipe_fd[2]; if (service->create_pipe_fd(pipe_fd) >= 0) { if (__set_fd_nonblock(pipe_fd[0]) >= 0) { service->ref = 1; data.operation = PD_OP_NOTIFY; data.fd = pipe_fd[0]; data.notify = IOService::aio_finish; data.context = service; data.result = NULL; if (mpoller_add(&data, -1, this->mpoller) >= 0) { service->pipe_fd[0] = pipe_fd[0]; service->pipe_fd[1] = pipe_fd[1]; return 0; } } close(pipe_fd[0]); close(pipe_fd[1]); } return -1; } void Communicator::io_unbind(IOService *service) { int errno_bak = errno; if (mpoller_del(service->pipe_fd[0], this->mpoller) < 0) { /* Error occurred on pipe_fd or Communicator::deinit() called. */ this->shutdown_io_service(service); errno = errno_bak; } } #endif workflow-0.11.8/src/kernel/Communicator.h000066400000000000000000000215731476003635400203530ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Author: Xie Han (xiehan@sogou-inc.com) */ #ifndef _COMMUNICATOR_H_ #define _COMMUNICATOR_H_ #include #include #include #include #include #include #include #include "list.h" #include "poller.h" class CommConnection { public: virtual ~CommConnection() { } }; class CommTarget { public: int init(const struct sockaddr *addr, socklen_t addrlen, int connect_timeout, int response_timeout); void deinit(); public: void get_addr(const struct sockaddr **addr, socklen_t *addrlen) const { *addr = this->addr; *addrlen = this->addrlen; } int has_idle_conn() const { return !list_empty(&this->idle_list); } protected: void set_ssl(SSL_CTX *ssl_ctx, int ssl_connect_timeout) { this->ssl_ctx = ssl_ctx; this->ssl_connect_timeout = ssl_connect_timeout; } SSL_CTX *get_ssl_ctx() const { return this->ssl_ctx; } private: virtual int create_connect_fd() { return socket(this->addr->sa_family, SOCK_STREAM, 0); } virtual CommConnection *new_connection(int connect_fd) { return new CommConnection; } virtual int init_ssl(SSL *ssl) { return 0; } public: virtual void release() { } private: struct sockaddr *addr; socklen_t addrlen; int connect_timeout; int response_timeout; int ssl_connect_timeout; SSL_CTX *ssl_ctx; private: struct list_head idle_list; pthread_mutex_t mutex; public: virtual ~CommTarget() { } friend class CommServiceTarget; friend class Communicator; }; class CommMessageOut { private: virtual int encode(struct iovec vectors[], int max) = 0; public: virtual ~CommMessageOut() { } friend class Communicator; }; class CommMessageIn : private poller_message_t { private: virtual int append(const void *buf, size_t *size) = 0; protected: /* Send small packet while receiving. Call only in append(). */ virtual int feedback(const void *buf, size_t size); /* In append(), reset the begin time of receiving to current time. */ virtual void renew(); /* Return the deepest wrapped message. 
*/ virtual CommMessageIn *inner() { return this; } private: struct CommConnEntry *entry; public: virtual ~CommMessageIn() { } friend class Communicator; }; #define CS_STATE_SUCCESS 0 #define CS_STATE_ERROR 1 #define CS_STATE_STOPPED 2 #define CS_STATE_TOREPLY 3 /* for service session only. */ class CommSession { private: virtual CommMessageOut *message_out() = 0; virtual CommMessageIn *message_in() = 0; virtual int send_timeout() { return -1; } virtual int receive_timeout() { return -1; } virtual int keep_alive_timeout() { return 0; } virtual int first_timeout() { return 0; } virtual void handle(int state, int error) = 0; protected: CommTarget *get_target() const { return this->target; } CommConnection *get_connection() const { return this->conn; } CommMessageOut *get_message_out() const { return this->out; } CommMessageIn *get_message_in() const { return this->in; } long long get_seq() const { return this->seq; } private: CommTarget *target; CommConnection *conn; CommMessageOut *out; CommMessageIn *in; long long seq; private: struct timespec begin_time; int timeout; int passive; public: CommSession() { this->passive = 0; } virtual ~CommSession(); friend class CommMessageIn; friend class Communicator; }; class CommService { public: int init(const struct sockaddr *bind_addr, socklen_t addrlen, int listen_timeout, int response_timeout); void deinit(); int drain(int max); public: void get_addr(const struct sockaddr **addr, socklen_t *addrlen) const { *addr = this->bind_addr; *addrlen = this->addrlen; } protected: void set_ssl(SSL_CTX *ssl_ctx, int ssl_accept_timeout) { this->ssl_ctx = ssl_ctx; this->ssl_accept_timeout = ssl_accept_timeout; } SSL_CTX *get_ssl_ctx() const { return this->ssl_ctx; } private: virtual CommSession *new_session(long long seq, CommConnection *conn) = 0; virtual void handle_stop(int error) { } virtual void handle_unbound() = 0; private: virtual int create_listen_fd() { return socket(this->bind_addr->sa_family, SOCK_STREAM, 0); } virtual 
CommConnection *new_connection(int accept_fd) { return new CommConnection; } virtual int init_ssl(SSL *ssl) { return 0; } private: struct sockaddr *bind_addr; socklen_t addrlen; int listen_timeout; int response_timeout; int ssl_accept_timeout; SSL_CTX *ssl_ctx; private: void incref(); void decref(); private: int reliable; int listen_fd; int ref; private: struct list_head keep_alive_list; pthread_mutex_t mutex; public: virtual ~CommService() { } friend class CommServiceTarget; friend class Communicator; }; #define SS_STATE_COMPLETE 0 #define SS_STATE_ERROR 1 #define SS_STATE_DISRUPTED 2 class SleepSession { private: virtual int duration(struct timespec *value) = 0; virtual void handle(int state, int error) = 0; private: void *timer; int index; public: virtual ~SleepSession() { } friend class Communicator; }; #ifdef __linux__ # include "IOService_linux.h" #else # include "IOService_thread.h" #endif class Communicator { public: int init(size_t poller_threads, size_t handler_threads); void deinit(); int request(CommSession *session, CommTarget *target); int reply(CommSession *session); int push(const void *buf, size_t size, CommSession *session); int shutdown(CommSession *session); int bind(CommService *service); void unbind(CommService *service); int sleep(SleepSession *session); int unsleep(SleepSession *session); int io_bind(IOService *service); void io_unbind(IOService *service); public: int is_handler_thread() const; int increase_handler_thread(); int decrease_handler_thread(); private: struct __mpoller *mpoller; struct __msgqueue *msgqueue; struct __thrdpool *thrdpool; int stop_flag; private: int create_poller(size_t poller_threads); int create_handler_threads(size_t handler_threads); void shutdown_service(CommService *service); void shutdown_io_service(IOService *service); int send_message_sync(struct iovec vectors[], int cnt, struct CommConnEntry *entry); int send_message_async(struct iovec vectors[], int cnt, struct CommConnEntry *entry); int 
send_message(struct CommConnEntry *entry); int request_new_conn(CommSession *session, CommTarget *target); int request_idle_conn(CommSession *session, CommTarget *target); int reply_message_unreliable(struct CommConnEntry *entry); int reply_reliable(CommSession *session, CommTarget *target); int reply_unreliable(CommSession *session, CommTarget *target); void handle_incoming_request(struct poller_result *res); void handle_incoming_reply(struct poller_result *res); void handle_request_result(struct poller_result *res); void handle_reply_result(struct poller_result *res); void handle_write_result(struct poller_result *res); void handle_read_result(struct poller_result *res); void handle_connect_result(struct poller_result *res); void handle_listen_result(struct poller_result *res); void handle_recvfrom_result(struct poller_result *res); void handle_ssl_accept_result(struct poller_result *res); void handle_sleep_result(struct poller_result *res); void handle_aio_result(struct poller_result *res); static void handler_thread_routine(void *context); static int nonblock_connect(CommTarget *target); static int nonblock_listen(CommService *service); static struct CommConnEntry *launch_conn(CommSession *session, CommTarget *target); static struct CommConnEntry *accept_conn(class CommServiceTarget *target, CommService *service); static int first_timeout(CommSession *session); static int next_timeout(CommSession *session); static int first_timeout_send(CommSession *session); static int first_timeout_recv(CommSession *session); static int append_message(const void *buf, size_t *size, poller_message_t *msg); static poller_message_t *create_request(void *context); static poller_message_t *create_reply(void *context); static int recv_request(const void *buf, size_t size, struct CommConnEntry *entry); static int partial_written(size_t n, void *context); static void *accept(const struct sockaddr *addr, socklen_t addrlen, int sockfd, void *context); static void *recvfrom(const struct 
sockaddr *addr, socklen_t addrlen, const void *buf, size_t size, void *context); static void callback(struct poller_result *res, void *context); public: virtual ~Communicator() { } }; #endif workflow-0.11.8/src/kernel/ExecRequest.h000066400000000000000000000025531476003635400201450ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Xie Han (xiehan@sogou-inc.com) */ #ifndef _EXECREQUEST_H_ #define _EXECREQUEST_H_ #include "SubTask.h" #include "Executor.h" class ExecRequest : public SubTask, public ExecSession { public: ExecRequest(ExecQueue *queue, Executor *executor) { this->executor = executor; this->queue = queue; } ExecQueue *get_request_queue() const { return this->queue; } void set_request_queue(ExecQueue *queue) { this->queue = queue; } public: virtual void dispatch() { if (this->executor->request(this, this->queue) < 0) this->handle(ES_STATE_ERROR, errno); } protected: int state; int error; protected: ExecQueue *queue; Executor *executor; protected: virtual void handle(int state, int error) { this->state = state; this->error = error; this->subtask_done(); } }; #endif workflow-0.11.8/src/kernel/Executor.cc000066400000000000000000000065261476003635400176500ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Xie Han (xiehan@sogou-inc.com) */ #include #include #include #include "list.h" #include "thrdpool.h" #include "Executor.h" struct ExecSessionEntry { struct list_head list; ExecSession *session; thrdpool_t *thrdpool; }; int ExecQueue::init() { int ret; ret = pthread_mutex_init(&this->mutex, NULL); if (ret == 0) { INIT_LIST_HEAD(&this->session_list); return 0; } errno = ret; return -1; } void ExecQueue::deinit() { pthread_mutex_destroy(&this->mutex); } int Executor::init(size_t nthreads) { this->thrdpool = thrdpool_create(nthreads, 0); if (this->thrdpool) return 0; return -1; } void Executor::deinit() { thrdpool_destroy(Executor::executor_cancel, this->thrdpool); } extern "C" void __thrdpool_schedule(const struct thrdpool_task *, void *, thrdpool_t *); void Executor::executor_thread_routine(void *context) { ExecQueue *queue = (ExecQueue *)context; struct ExecSessionEntry *entry; ExecSession *session; int empty; entry = list_entry(queue->session_list.next, struct ExecSessionEntry, list); pthread_mutex_lock(&queue->mutex); list_del(&entry->list); empty = list_empty(&queue->session_list); pthread_mutex_unlock(&queue->mutex); session = entry->session; if (!empty) { struct thrdpool_task task = { .routine = Executor::executor_thread_routine, .context = queue }; __thrdpool_schedule(&task, entry, entry->thrdpool); } else free(entry); session->execute(); session->handle(ES_STATE_FINISHED, 0); } void Executor::executor_cancel(const struct thrdpool_task *task) { ExecQueue *queue = (ExecQueue *)task->context; struct ExecSessionEntry *entry; struct list_head *pos, *tmp; ExecSession 
*session; list_for_each_safe(pos, tmp, &queue->session_list) { entry = list_entry(pos, struct ExecSessionEntry, list); list_del(pos); session = entry->session; free(entry); session->handle(ES_STATE_CANCELED, 0); } } int Executor::request(ExecSession *session, ExecQueue *queue) { struct ExecSessionEntry *entry; session->queue = queue; entry = (struct ExecSessionEntry *)malloc(sizeof (struct ExecSessionEntry)); if (entry) { entry->session = session; entry->thrdpool = this->thrdpool; pthread_mutex_lock(&queue->mutex); list_add_tail(&entry->list, &queue->session_list); if (queue->session_list.next == &entry->list) { struct thrdpool_task task = { .routine = Executor::executor_thread_routine, .context = queue }; if (thrdpool_schedule(&task, this->thrdpool) < 0) { list_del(&entry->list); free(entry); entry = NULL; } } pthread_mutex_unlock(&queue->mutex); } return -!entry; } int Executor::increase_thread() { return thrdpool_increase(this->thrdpool); } int Executor::decrease_thread() { return thrdpool_decrease(this->thrdpool); } workflow-0.11.8/src/kernel/Executor.h000066400000000000000000000031671476003635400175100ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Author: Xie Han (xiehan@sogou-inc.com) */ #ifndef _EXECUTOR_H_ #define _EXECUTOR_H_ #include #include #include "list.h" class ExecQueue { public: int init(); void deinit(); private: struct list_head session_list; pthread_mutex_t mutex; public: virtual ~ExecQueue() { } friend class Executor; }; #define ES_STATE_FINISHED 0 #define ES_STATE_ERROR 1 #define ES_STATE_CANCELED 2 class ExecSession { private: virtual void execute() = 0; virtual void handle(int state, int error) = 0; protected: ExecQueue *get_queue() const { return this->queue; } private: ExecQueue *queue; public: virtual ~ExecSession() { } friend class Executor; }; class Executor { public: int init(size_t nthreads); void deinit(); int request(ExecSession *session, ExecQueue *queue); public: int increase_thread(); int decrease_thread(); private: struct __thrdpool *thrdpool; private: static void executor_thread_routine(void *context); static void executor_cancel(const struct thrdpool_task *task); public: virtual ~Executor() { } }; #endif workflow-0.11.8/src/kernel/IORequest.h000066400000000000000000000022531476003635400175650ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Author: Xie Han (xiehan@sogou-inc.com) */ #ifndef _IOREQUEST_H_ #define _IOREQUEST_H_ #include #include "SubTask.h" #include "Communicator.h" class IORequest : public SubTask, public IOSession { public: IORequest(IOService *service) { this->service = service; } public: virtual void dispatch() { if (this->service->request(this) < 0) this->handle(IOS_STATE_ERROR, errno); } protected: int state; int error; protected: IOService *service; protected: virtual void handle(int state, int error) { this->state = state; this->error = error; this->subtask_done(); } }; #endif workflow-0.11.8/src/kernel/IOService_linux.cc000066400000000000000000000211231476003635400211070ustar00rootroot00000000000000/* Copyright (c) 2020 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Author: Xie Han (xiehan@sogou-inc.com) */ #include #include #include #include #include #include #include #include "list.h" #include "IOService_linux.h" /* Linux async I/O interface from libaio.h */ typedef struct io_context *io_context_t; typedef enum io_iocb_cmd { IO_CMD_PREAD = 0, IO_CMD_PWRITE = 1, IO_CMD_FSYNC = 2, IO_CMD_FDSYNC = 3, IO_CMD_POLL = 5, IO_CMD_NOOP = 6, IO_CMD_PREADV = 7, IO_CMD_PWRITEV = 8, } io_iocb_cmd_t; /* little endian, 32 bits */ #if defined(__i386__) || (defined(__arm__) && !defined(__ARMEB__)) || \ defined(__sh__) || defined(__bfin__) || defined(__MIPSEL__) || \ defined(__cris__) || (defined(__riscv) && __riscv_xlen == 32) || \ (defined(__GNUC__) && defined(__BYTE_ORDER__) && \ __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ && __SIZEOF_LONG__ == 4) #define PADDED(x, y) x; unsigned y #define PADDEDptr(x, y) x; unsigned y #define PADDEDul(x, y) unsigned long x; unsigned y /* little endian, 64 bits */ #elif defined(__ia64__) || defined(__x86_64__) || defined(__alpha__) || \ (defined(__aarch64__) && defined(__AARCH64EL__)) || \ (defined(__riscv) && __riscv_xlen == 64) || \ (defined(__GNUC__) && defined(__BYTE_ORDER__) && \ __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ && __SIZEOF_LONG__ == 8) #define PADDED(x, y) x, y #define PADDEDptr(x, y) x #define PADDEDul(x, y) unsigned long x /* big endian, 64 bits */ #elif defined(__powerpc64__) || defined(__s390x__) || \ (defined(__sparc__) && defined(__arch64__)) || \ (defined(__aarch64__) && defined(__AARCH64EB__)) || \ (defined(__GNUC__) && defined(__BYTE_ORDER__) && \ __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ && __SIZEOF_LONG__ == 8) #define PADDED(x, y) unsigned y; x #define PADDEDptr(x,y) x #define PADDEDul(x, y) unsigned long x /* big endian, 32 bits */ #elif defined(__PPC__) || defined(__s390__) || \ (defined(__arm__) && defined(__ARMEB__)) || \ defined(__sparc__) || defined(__MIPSEB__) || defined(__m68k__) || \ defined(__hppa__) || defined(__frv__) || defined(__avr32__) || \ (defined(__GNUC__) && 
defined(__BYTE_ORDER__) && \ __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ && __SIZEOF_LONG__ == 4) #define PADDED(x, y) unsigned y; x #define PADDEDptr(x, y) unsigned y; x #define PADDEDul(x, y) unsigned y; unsigned long x #else #error endian? #endif struct io_iocb_poll { PADDED(int events, __pad1); }; /* result code is the set of result flags or -'ve errno */ struct io_iocb_sockaddr { struct sockaddr *addr; int len; }; /* result code is the length of the sockaddr, or -'ve errno */ struct io_iocb_common { PADDEDptr(void *buf, __pad1); PADDEDul(nbytes, __pad2); long long offset; long long __pad3; unsigned flags; unsigned resfd; }; /* result code is the amount read or -'ve errno */ struct io_iocb_vector { const struct iovec *vec; int nr; long long offset; }; /* result code is the amount read or -'ve errno */ struct iocb { PADDEDptr(void *data, __pad1); /* Return in the io completion event */ /* key: For use in identifying io requests */ /* aio_rw_flags: RWF_* flags (such as RWF_NOWAIT) */ PADDED(unsigned key, aio_rw_flags); short aio_lio_opcode; short aio_reqprio; int aio_fildes; union { struct io_iocb_common c; struct io_iocb_vector v; struct io_iocb_poll poll; struct io_iocb_sockaddr saddr; } u; }; struct io_event { PADDEDptr(void *data, __pad1); PADDEDptr(struct iocb *obj, __pad2); PADDEDul(res, __pad3); PADDEDul(res2, __pad4); }; #undef PADDED #undef PADDEDptr #undef PADDEDul /* Actual syscalls */ static inline int io_setup(int maxevents, io_context_t *ctxp) { return syscall(__NR_io_setup, maxevents, ctxp); } static inline int io_destroy(io_context_t ctx) { return syscall(__NR_io_destroy, ctx); } static inline int io_submit(io_context_t ctx, long nr, struct iocb *ios[]) { return syscall(__NR_io_submit, ctx, nr, ios); } static inline int io_cancel(io_context_t ctx, struct iocb *iocb, struct io_event *evt) { return syscall(__NR_io_cancel, ctx, iocb, evt); } static inline int io_getevents(io_context_t ctx_id, long min_nr, long nr, struct io_event *events, struct timespec 
*timeout) { return syscall(__NR_io_getevents, ctx_id, min_nr, nr, events, timeout); } static inline void io_set_eventfd(struct iocb *iocb, int eventfd) { iocb->u.c.flags |= (1 << 0) /* IOCB_FLAG_RESFD */; iocb->u.c.resfd = eventfd; } void IOSession::prep_pread(int fd, void *buf, size_t count, long long offset) { struct iocb *iocb = (struct iocb *)this->iocb_buf; memset(iocb, 0, sizeof(*iocb)); iocb->aio_fildes = fd; iocb->aio_lio_opcode = IO_CMD_PREAD; iocb->u.c.buf = buf; iocb->u.c.nbytes = count; iocb->u.c.offset = offset; } void IOSession::prep_pwrite(int fd, void *buf, size_t count, long long offset) { struct iocb *iocb = (struct iocb *)this->iocb_buf; memset(iocb, 0, sizeof(*iocb)); iocb->aio_fildes = fd; iocb->aio_lio_opcode = IO_CMD_PWRITE; iocb->u.c.buf = buf; iocb->u.c.nbytes = count; iocb->u.c.offset = offset; } void IOSession::prep_preadv(int fd, const struct iovec *iov, int iovcnt, long long offset) { struct iocb *iocb = (struct iocb *)this->iocb_buf; memset(iocb, 0, sizeof(*iocb)); iocb->aio_fildes = fd; iocb->aio_lio_opcode = IO_CMD_PREADV; iocb->u.c.buf = (void *)iov; iocb->u.c.nbytes = iovcnt; iocb->u.c.offset = offset; } void IOSession::prep_pwritev(int fd, const struct iovec *iov, int iovcnt, long long offset) { struct iocb *iocb = (struct iocb *)this->iocb_buf; memset(iocb, 0, sizeof(*iocb)); iocb->aio_fildes = fd; iocb->aio_lio_opcode = IO_CMD_PWRITEV; iocb->u.c.buf = (void *)iov; iocb->u.c.nbytes = iovcnt; iocb->u.c.offset = offset; } void IOSession::prep_fsync(int fd) { struct iocb *iocb = (struct iocb *)this->iocb_buf; memset(iocb, 0, sizeof(*iocb)); iocb->aio_fildes = fd; iocb->aio_lio_opcode = IO_CMD_FSYNC; } void IOSession::prep_fdsync(int fd) { struct iocb *iocb = (struct iocb *)this->iocb_buf; memset(iocb, 0, sizeof(*iocb)); iocb->aio_fildes = fd; iocb->aio_lio_opcode = IO_CMD_FDSYNC; } int IOService::init(int maxevents) { int ret; if (maxevents < 0) { errno = EINVAL; return -1; } this->io_ctx = NULL; if (io_setup(maxevents, 
&this->io_ctx) >= 0) { ret = pthread_mutex_init(&this->mutex, NULL); if (ret == 0) { INIT_LIST_HEAD(&this->session_list); this->event_fd = -1; return 0; } errno = ret; io_destroy(this->io_ctx); } return -1; } void IOService::deinit() { pthread_mutex_destroy(&this->mutex); io_destroy(this->io_ctx); } inline void IOService::incref() { __sync_add_and_fetch(&this->ref, 1); } void IOService::decref() { IOSession *session; struct io_event event; int state, error; if (__sync_sub_and_fetch(&this->ref, 1) == 0) { while (!list_empty(&this->session_list)) { if (io_getevents(this->io_ctx, 1, 1, &event, NULL) > 0) { session = (IOSession *)event.data; list_del(&session->list); session->res = event.res; if (session->res >= 0) { state = IOS_STATE_SUCCESS; error = 0; } else { state = IOS_STATE_ERROR; error = -session->res; } session->handle(state, error); } } this->handle_unbound(); } } int IOService::request(IOSession *session) { struct iocb *iocb = (struct iocb *)session->iocb_buf; int ret = -1; pthread_mutex_lock(&this->mutex); if (this->event_fd < 0) errno = ENOENT; else if (session->prepare() >= 0) { io_set_eventfd(iocb, this->event_fd); iocb->data = session; if (io_submit(this->io_ctx, 1, &iocb) > 0) { list_add_tail(&session->list, &this->session_list); ret = 0; } } pthread_mutex_unlock(&this->mutex); if (ret < 0) session->res = -errno; return ret; } void *IOService::aio_finish(void *context) { IOService *service = (IOService *)context; IOSession *session; struct io_event event; if (io_getevents(service->io_ctx, 1, 1, &event, NULL) > 0) { service->incref(); session = (IOSession *)event.data; session->res = event.res; return session; } return NULL; } workflow-0.11.8/src/kernel/IOService_linux.h000066400000000000000000000042051476003635400207530ustar00rootroot00000000000000/* Copyright (c) 2020 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Xie Han (xiehan@sogou-inc.com) */ #ifndef _IOSERVICE_LINUX_H_ #define _IOSERVICE_LINUX_H_ #include #include #include #include #include "list.h" #define IOS_STATE_SUCCESS 0 #define IOS_STATE_ERROR 1 class IOSession { private: virtual int prepare() = 0; virtual void handle(int state, int error) = 0; protected: /* prepare() has to call one of the the prep_ functions. */ void prep_pread(int fd, void *buf, size_t count, long long offset); void prep_pwrite(int fd, void *buf, size_t count, long long offset); void prep_preadv(int fd, const struct iovec *iov, int iovcnt, long long offset); void prep_pwritev(int fd, const struct iovec *iov, int iovcnt, long long offset); void prep_fsync(int fd); void prep_fdsync(int fd); protected: long get_res() const { return this->res; } private: char iocb_buf[64]; long res; private: struct list_head list; public: virtual ~IOSession() { } friend class IOService; friend class Communicator; }; class IOService { public: int init(int maxevents); void deinit(); int request(IOSession *session); private: virtual void handle_stop(int error) { } virtual void handle_unbound() = 0; private: virtual int create_event_fd() { return eventfd(0, 0); } private: struct io_context *io_ctx; private: void incref(); void decref(); private: int event_fd; int ref; private: struct list_head session_list; pthread_mutex_t mutex; private: static void *aio_finish(void *context); public: virtual ~IOService() { } friend class Communicator; }; #endif workflow-0.11.8/src/kernel/IOService_thread.cc000066400000000000000000000140361476003635400212240ustar00rootroot00000000000000/* 
Copyright (c) 2020 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Xie Han (xiehan@sogou-inc.com) */ #include #include #include #include #include #include "list.h" #include "IOService_thread.h" typedef enum io_iocb_cmd { IO_CMD_PREAD = 0, IO_CMD_PWRITE = 1, IO_CMD_FSYNC = 2, IO_CMD_FDSYNC = 3, IO_CMD_NOOP = 6, IO_CMD_PREADV = 7, IO_CMD_PWRITEV = 8, } io_iocb_cmd_t; void IOSession::prep_pread(int fd, void *buf, size_t count, long long offset) { this->fd = fd; this->op = IO_CMD_PREAD; this->buf = buf; this->count = count; this->offset = offset; } void IOSession::prep_pwrite(int fd, void *buf, size_t count, long long offset) { this->fd = fd; this->op = IO_CMD_PWRITE; this->buf = buf; this->count = count; this->offset = offset; } void IOSession::prep_preadv(int fd, const struct iovec *iov, int iovcnt, long long offset) { this->fd = fd; this->op = IO_CMD_PREADV; this->buf = (void *)iov; this->count = iovcnt; this->offset = offset; } void IOSession::prep_pwritev(int fd, const struct iovec *iov, int iovcnt, long long offset) { this->fd = fd; this->op = IO_CMD_PWRITEV; this->buf = (void *)iov; this->count = iovcnt; this->offset = offset; } void IOSession::prep_fsync(int fd) { this->fd = fd; this->op = IO_CMD_FSYNC; } void IOSession::prep_fdsync(int fd) { this->fd = fd; this->op = IO_CMD_FDSYNC; } int IOService::init(int maxevents) { void *p; int ret; if (maxevents <= 0) { errno = EINVAL; return -1; } ret = pthread_mutex_init(&this->mutex, NULL); if (ret) { errno = ret; return -1; } p = 
dlsym(RTLD_DEFAULT, "preadv"); if (p) this->preadv = (ssize_t (*)(int, const struct iovec *, int, off_t))p; else this->preadv = IOService::preadv_emul; p = dlsym(RTLD_DEFAULT, "pwritev"); if (p) this->pwritev = (ssize_t (*)(int, const struct iovec *, int, off_t))p; else this->pwritev = IOService::pwritev_emul; this->maxevents = maxevents; this->nevents = 0; INIT_LIST_HEAD(&this->session_list); this->pipe_fd[0] = -1; this->pipe_fd[1] = -1; return 0; } void IOService::deinit() { pthread_mutex_destroy(&this->mutex); } inline void IOService::incref() { __sync_add_and_fetch(&this->ref, 1); } void IOService::decref() { IOSession *session; int state, error; if (__sync_sub_and_fetch(&this->ref, 1) == 0) { while (!list_empty(&this->session_list)) { session = list_entry(this->session_list.next, IOSession, list); pthread_join(session->tid, NULL); list_del(&session->list); if (session->res >= 0) { state = IOS_STATE_SUCCESS; error = 0; } else { state = IOS_STATE_ERROR; error = -session->res; } session->handle(state, error); } pthread_mutex_lock(&this->mutex); /* Wait for detached threads. 
*/
		pthread_mutex_unlock(&this->mutex);
		this->handle_unbound();
	}
}

/*
 * Submit one I/O session to this service.
 *
 * All checks run under the service mutex:
 *   - pipe_fd[0] < 0            -> fails with errno = ENOENT
 *   - nevents >= maxevents      -> fails with errno = EAGAIN
 *   - session->prepare() < 0    -> fails; errno is not touched here
 *     (presumably left as prepare() set it -- prepare() is not visible)
 * Otherwise a thread running io_routine() is created for the session, the
 * session is appended to session_list and nevents is incremented.
 *
 * Returns 0 on success and -1 on failure. On failure session->res is set
 * to -errno.
 */
int IOService::request(IOSession *session)
{
	pthread_t tid;
	int ret = -1;

	pthread_mutex_lock(&this->mutex);
	if (this->pipe_fd[0] < 0)
		errno = ENOENT;
	else if (this->nevents >= this->maxevents)
		errno = EAGAIN;
	else if (session->prepare() >= 0)
	{
		session->service = this;
		ret = pthread_create(&tid, NULL, IOService::io_routine, session);
		if (ret == 0)
		{
			session->tid = tid;
			list_add_tail(&session->list, &this->session_list);
			this->nevents++;
		}
		else
		{
			/* pthread_create() returns the error number instead of
			 * setting errno; convert to the -1/errno convention. */
			errno = ret;
			ret = -1;
		}
	}

	pthread_mutex_unlock(&this->mutex);
	if (ret < 0)
		session->res = -errno;

	return ret;
}

/* When _POSIX_SYNCHRONIZED_IO is not positive, provide fdatasync() as a
 * plain fsync(). */
#if _POSIX_SYNCHRONIZED_IO <= 0
static inline int fdatasync(int fd)
{
	return fsync(fd);
}
#endif

/*
 * Thread entry point started by request(); 'arg' is the IOSession.
 *
 * Runs the operation selected by session->op on session->fd and stores the
 * call's return value, or -errno when it failed, in session->res. Then,
 * under the service mutex, writes the session pointer to pipe_fd[1] (only if
 * that descriptor is still >= 0) and decrements nevents. Always returns NULL.
 */
void *IOService::io_routine(void *arg)
{
	IOSession *session = (IOSession *)arg;
	IOService *service = session->service;
	int fd = session->fd;
	ssize_t ret;

	switch (session->op)
	{
	case IO_CMD_PREAD:
		ret = pread(fd, session->buf, session->count, session->offset);
		break;
	case IO_CMD_PWRITE:
		ret = pwrite(fd, session->buf, session->count, session->offset);
		break;
	case IO_CMD_FSYNC:
		ret = fsync(fd);
		break;
	case IO_CMD_FDSYNC:
		ret = fdatasync(fd);
		break;
	case IO_CMD_PREADV:
		/* For the vectored operations buf holds the iovec array and
		 * count is passed as the iovec count. */
		ret = service->preadv(fd, (const struct iovec *)session->buf,
							  session->count, session->offset);
		break;
	case IO_CMD_PWRITEV:
		ret = service->pwritev(fd, (const struct iovec *)session->buf,
							   session->count, session->offset);
		break;
	default:
		errno = EINVAL;
		ret = -1;
		break;
	}

	if (ret < 0)
		ret = -errno;

	session->res = ret;
	pthread_mutex_lock(&service->mutex);
	/* The message written to the pipe is the session pointer itself.
	 * The return value of write() is not checked. */
	if (service->pipe_fd[1] >= 0)
		write(service->pipe_fd[1], &session, sizeof (void *));

	service->nevents--;
	pthread_mutex_unlock(&service->mutex);
	return NULL;
}

/*
 * Completion hook: 'ptr' is the IOSession, 'context' the IOService.
 * Takes a reference on the service, detaches the session's worker thread
 * and returns the session pointer.
 * NOTE(review): the caller is not visible here; presumably invoked by the
 * reader of pipe_fd[0] for each session pointer received -- confirm.
 */
void *IOService::aio_finish(void *ptr, void *context)
{
	IOService *service = (IOService *)context;
	IOSession *session = (IOSession *)ptr;

	service->incref();
	pthread_detach(session->tid);
	return session;
}

/*
 * Emulate a positional vectored read with one pread() per iovec element.
 *
 * Stops at the first failing or short pread(). Returns the number of bytes
 * read so far; returns -1 only when the failure happens before any byte was
 * read (errno is then whatever pread() set).
 */
ssize_t IOService::preadv_emul(int fd, const struct iovec *iov, int iovcnt,
							   off_t offset)
{
	size_t total = 0;
	ssize_t n;
	int i;

	for (i = 0; i < iovcnt; i++)
	{
		n = pread(fd, iov[i].iov_base, iov[i].iov_len, offset);
		if (n < 0)
			return total == 0 ? -1 : total;

		total += n;
		/* Short read: report what we have instead of moving on. */
		if ((size_t)n < iov[i].iov_len)
			return total;

		offset += n;
	}

	return total;
}

/*
 * Emulate a positional vectored write with one pwrite() per iovec element.
 *
 * Stops at the first failing or short pwrite(). Returns the number of bytes
 * written so far; returns -1 only when the failure happens before any byte
 * was written.
 */
ssize_t IOService::pwritev_emul(int fd, const struct iovec *iov, int iovcnt,
								off_t offset)
{
	size_t total = 0;
	ssize_t n;
	int i;

	for (i = 0; i < iovcnt; i++)
	{
		n = pwrite(fd, iov[i].iov_base, iov[i].iov_len, offset);
		if (n < 0)
			return total == 0 ? -1 : total;

		total += n;
		/* Short write: report what we have instead of moving on. */
		if ((size_t)n < iov[i].iov_len)
			return total;

		offset += n;
	}

	return total;
}
workflow-0.11.8/src/kernel/IOService_thread.h000066400000000000000000000051361476003635400210670ustar00rootroot00000000000000/*
  Copyright (c) 2020 Sogou, Inc.

  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

  Unless required by applicable law or agreed to in writing, software
  distributed under the License is distributed on an "AS IS" BASIS,
  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  See the License for the specific language governing permissions and
  limitations under the License.

  Author: Xie Han (xiehan@sogou-inc.com)
*/

#ifndef _IOSERVICE_THREAD_H_
#define _IOSERVICE_THREAD_H_

#include
#include
#include
#include
#include "list.h"

/* State values passed to IOSession::handle(). */
#define IOS_STATE_SUCCESS	0
#define IOS_STATE_ERROR		1

/*
 * One file I/O request submitted through IOService::request().
 * Subclasses implement prepare(), which request() calls while holding the
 * service mutex, and handle(), which receives a state and an error code.
 */
class IOSession
{
private:
	virtual int prepare() = 0;
	virtual void handle(int state, int error) = 0;

protected:
	/* prepare() has to call one of the prep_ functions. */
	void prep_pread(int fd, void *buf, size_t count, long long offset);
	void prep_pwrite(int fd, void *buf, size_t count, long long offset);
	void prep_preadv(int fd, const struct iovec *iov, int iovcnt,
					 long long offset);
	void prep_pwritev(int fd, const struct iovec *iov, int iovcnt,
					  long long offset);
	void prep_fsync(int fd);
	void prep_fdsync(int fd);

protected:
	/* Result of the operation: the I/O call's return value, or a negated
	 * errno value on failure. */
	long get_res() const { return this->res; }

private:
	/* Request description consumed by IOService::io_routine(). */
	int fd;
	int op;
	void *buf;
	size_t count;
	long long offset;
	long res;

private:
	/* Bookkeeping written by IOService::request(). */
	struct list_head list;
	class IOService *service;
	pthread_t tid;

public:
	virtual ~IOSession() { }
	friend class IOService;
	friend class Communicator;
};

/*
 * Runs IOSession requests, one thread per request, and reports completed
 * sessions by writing their pointers to a pipe.
 */
class IOService
{
public:
	int init(int maxevents);
	void deinit();

	int request(IOSession *session);

private:
	virtual void handle_stop(int error) { }
	virtual void handle_unbound() = 0;

private:
	/* Overridable factory for the notification pipe; defaults to pipe(). */
	virtual int create_pipe_fd(int pipe_fd[2])
	{
		return pipe(pipe_fd);
	}

private:
	/* request() refuses new sessions once nevents reaches maxevents. */
	int maxevents;
	int nevents;

private:
	void incref();
	void decref();

private:
	int pipe_fd[2];
	int ref;

private:
	struct list_head session_list;
	pthread_mutex_t mutex;

private:
	static void *io_routine(void *arg);
	static void *aio_finish(void *ptr, void *context);

private:
	static ssize_t preadv_emul(int fd, const struct iovec *iov, int iovcnt,
							   off_t offset);
	static ssize_t pwritev_emul(int fd, const struct iovec *iov, int iovcnt,
								off_t offset);

	/* Called by io_routine() for IO_CMD_PREADV / IO_CMD_PWRITEV.
	 * NOTE(review): where these are assigned is not visible here;
	 * presumably either the system calls or the *_emul fallbacks. */
	ssize_t (*preadv)(int, const struct iovec *, int, off_t);
	ssize_t (*pwritev)(int, const struct iovec *, int, off_t);

public:
	virtual ~IOService() { }
	friend class Communicator;
};

#endif
workflow-0.11.8/src/kernel/SleepRequest.h000066400000000000000000000024551476003635400203320ustar00rootroot00000000000000/*
  Copyright (c) 2019 Sogou, Inc.

  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

  Unless required by applicable law or agreed to in writing, software
  distributed under the License is distributed on an "AS IS" BASIS,
  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  See the License for the specific language governing permissions and
  limitations under the License.

  Author: Xie Han (xiehan@sogou-inc.com)
*/

#ifndef _SLEEPREQUEST_H_
#define _SLEEPREQUEST_H_

#include
#include "SubTask.h"
#include "Communicator.h"
#include "CommScheduler.h"

/*
 * A SubTask that registers itself as a SleepSession with a CommScheduler.
 * The outcome is delivered through handle(), which records state and error
 * and then completes the subtask.
 */
class SleepRequest : public SubTask, public SleepSession
{
public:
	SleepRequest(CommScheduler *scheduler)
	{
		this->scheduler = scheduler;
	}

public:
	/* Starts the sleep. If the scheduler rejects it, the failure is
	 * reported immediately as SS_STATE_ERROR with the current errno. */
	virtual void dispatch()
	{
		if (this->scheduler->sleep(this) < 0)
			this->handle(SS_STATE_ERROR, errno);
	}

protected:
	/* Forwards to CommScheduler::unsleep() and returns its result. */
	int cancel()
	{
		return this->scheduler->unsleep(this);
	}

protected:
	/* Filled in by handle(). */
	int state;
	int error;

protected:
	CommScheduler *scheduler;

protected:
	virtual void handle(int state, int error)
	{
		this->state = state;
		this->error = error;
		this->subtask_done();
	}
};

#endif
workflow-0.11.8/src/kernel/SubTask.cc000066400000000000000000000023741476003635400174230ustar00rootroot00000000000000/*
  Copyright (c) 2019 Sogou, Inc.

  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

  Unless required by applicable law or agreed to in writing, software
  distributed under the License is distributed on an "AS IS" BASIS,
  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  See the License for the specific language governing permissions and
  limitations under the License.
Author: Xie Han (xiehan@sogou-inc.com) */

#include "SubTask.h"

/*
 * Called when this subtask has finished; drives whatever comes next.
 *
 * Each iteration calls done() on the finished task:
 *   - if done() returns a task, that task inherits the finished task's
 *     parent and is dispatched; the loop ends;
 *   - otherwise, if there is a parent, its nleft counter is decremented
 *     atomically. The caller that brings it to zero treats the parent as
 *     the next finished task and loops again;
 *   - otherwise the loop ends.
 */
void SubTask::subtask_done()
{
	SubTask *cur = this;
	ParallelTask *parent;

	while (1)
	{
		/* parent is read before done() is called.
		 * NOTE(review): presumably because done() may destroy cur --
		 * done() implementations are not visible here. */
		parent = cur->parent;
		cur = cur->done();
		if (cur)
		{
			cur->parent = parent;
			cur->dispatch();
		}
		else if (parent)
		{
			if (__sync_sub_and_fetch(&parent->nleft, 1) == 0)
			{
				cur = parent;
				continue;
			}
		}

		break;
	}
}

/*
 * Dispatch every subtask with this task as its parent. nleft is set to the
 * number of subtasks before the first dispatch. With no subtasks the
 * parallel task completes immediately through subtask_done().
 */
void ParallelTask::dispatch()
{
	/* Both bounds are copied to locals up front; the loop condition
	 * below compares only these locals. */
	SubTask **end = this->subtasks + this->subtasks_nr;
	SubTask **p = this->subtasks;

	this->nleft = this->subtasks_nr;
	if (this->nleft != 0)
	{
		do
		{
			(*p)->parent = this;
			(*p)->dispatch();
		} while (++p != end);
	}
	else
		this->subtask_done();
}
workflow-0.11.8/src/kernel/SubTask.h000066400000000000000000000026551476003635400172670ustar00rootroot00000000000000/*
  Copyright (c) 2019 Sogou, Inc.

  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

  Unless required by applicable law or agreed to in writing, software
  distributed under the License is distributed on an "AS IS" BASIS,
  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  See the License for the specific language governing permissions and
  limitations under the License.
Author: Xie Han (xiehan@sogou-inc.com) */

#ifndef _SUBTASK_H_
#define _SUBTASK_H_

#include

/*
 * Base class of a unit of work. Implementations provide dispatch() to start
 * the work and done() to finish it; done() returns the next SubTask to run,
 * or NULL when there is none. subtask_done() is the entry point an
 * implementation calls once its work has completed.
 */
class SubTask
{
public:
	virtual void dispatch() = 0;

private:
	virtual SubTask *done() = 0;

protected:
	void subtask_done();

public:
	/* Opaque user pointer; stored and returned unchanged by this class. */
	void *get_pointer() const { return this->pointer; }
	void set_pointer(void *pointer) { this->pointer = pointer; }

private:
	/* Enclosing ParallelTask, or NULL. Set by ParallelTask::dispatch()
	 * and propagated by subtask_done(). */
	class ParallelTask *parent;
	void *pointer;

public:
	SubTask()
	{
		this->parent = NULL;
		this->pointer = NULL;
	}

	virtual ~SubTask() { }
	friend class ParallelTask;
};

/*
 * A SubTask made of an array of subtasks that are all dispatched together.
 * It is considered finished when the nleft counter reaches zero.
 */
class ParallelTask : public SubTask
{
public:
	virtual void dispatch();

protected:
	/* The array is referenced, not copied, by the constructor. */
	SubTask **subtasks;
	size_t subtasks_nr;

private:
	/* Subtasks still outstanding. Not set by the constructor; it is
	 * initialised by dispatch(). */
	size_t nleft;

public:
	ParallelTask(SubTask **subtasks, size_t n)
	{
		this->subtasks = subtasks;
		this->subtasks_nr = n;
	}

	virtual ~ParallelTask() { }
	friend class SubTask;
};

#endif
workflow-0.11.8/src/kernel/list.h000066400000000000000000000204171476003635400166620ustar00rootroot00000000000000#ifndef _LINUX_LIST_H
#define _LINUX_LIST_H

/*
 * Circular doubly linked list implementation.
 *
 * Some of the internal functions ("__xxx") are useful when
 * manipulating whole lists rather than single entries, as
 * sometimes we already know the next/prev entries and we can
 * generate better code by using them directly rather than
 * using the generic single-entry routines.
 */

/* Link node; embed it in the structure that is to be put on a list. */
struct list_head {
	struct list_head *next, *prev;
};

/* Static initializer: both links point at the head itself (empty list). */
#define LIST_HEAD_INIT(name) { &(name), &(name) }

/* Define and initialize an empty list head called 'name'. */
#define LIST_HEAD(name) \
	struct list_head name = LIST_HEAD_INIT(name)

/**
 * INIT_LIST_HEAD - Initialize a list_head structure
 * @list: list_head structure to be initialized.
 *
 * Initializes the list_head to point to itself. If it is a list header,
 * the result is an empty list.
 */
static inline void INIT_LIST_HEAD(struct list_head *list)
{
	list->next = list;
	list->prev = list;
}

/*
 * Insert a new entry between two known consecutive entries.
 *
 * This is only for internal list manipulation where we know
 * the prev/next entries already!
*/
static inline void __list_add(struct list_head *entry,
							  struct list_head *prev,
							  struct list_head *next)
{
	/* Four pointer stores link 'entry' between 'prev' and 'next'. */
	next->prev = entry;
	entry->next = next;
	entry->prev = prev;
	prev->next = entry;
}

/**
 * list_add - add a new entry
 * @entry: new entry to be added
 * @head: list head to add it after
 *
 * Insert a new entry after the specified head.
 * This is good for implementing stacks.
 */
static inline void list_add(struct list_head *entry, struct list_head *head)
{
	__list_add(entry, head, head->next);
}

/**
 * list_add_tail - add a new entry
 * @entry: new entry to be added
 * @head: list head to add it before
 *
 * Insert a new entry before the specified head.
 * This is useful for implementing queues.
 */
static inline void list_add_tail(struct list_head *entry,
								 struct list_head *head)
{
	__list_add(entry, head->prev, head);
}

/*
 * Delete a list entry by making the prev/next entries
 * point to each other.
 *
 * This is only for internal list manipulation where we know
 * the prev/next entries already!
 *
 * The removed entry itself is not passed in and is not modified.
 */
static inline void __list_del(struct list_head *prev, struct list_head *next)
{
	next->prev = prev;
	prev->next = next;
}

/**
 * list_del - deletes entry from list.
 * @entry: the element to delete from the list.
 * Note: list_empty() on entry does not return true after this, the entry is
 * in an undefined state.
*/
static inline void list_del(struct list_head *entry)
{
	/* entry->next and entry->prev are left as they were. */
	__list_del(entry->prev, entry->next);
}

/**
 * list_move - delete from one list and add as another's head
 * @entry: the entry to move
 * @head: the head that will precede our entry
 */
static inline void list_move(struct list_head *entry, struct list_head *head)
{
	__list_del(entry->prev, entry->next);
	list_add(entry, head);
}

/**
 * list_move_tail - delete from one list and add as another's tail
 * @entry: the entry to move
 * @head: the head that will follow our entry
 */
static inline void list_move_tail(struct list_head *entry,
								  struct list_head *head)
{
	__list_del(entry->prev, entry->next);
	list_add_tail(entry, head);
}

/**
 * list_empty - tests whether a list is empty
 * @head: the list to test.
 *
 * Returns non-zero when @head->next points back at @head.
 */
static inline int list_empty(const struct list_head *head)
{
	return head->next == head;
}

/*
 * Link every entry of @list between @prev and @next.
 *
 * @list must not be empty (the public wrappers check this). The head
 * @list itself is not modified and still points at the moved entries.
 */
static inline void __list_splice(const struct list_head *list,
								 struct list_head *prev,
								 struct list_head *next)
{
	struct list_head *first = list->next;
	struct list_head *last = list->prev;

	first->prev = prev;
	prev->next = first;

	last->next = next;
	next->prev = last;
}

/**
 * list_splice - join two lists
 * @list: the new list to add.
 * @head: the place to add it in the first list.
 *
 * Does nothing when @list is empty. @list is not reinitialised; see
 * list_splice_init() for that.
 */
static inline void list_splice(const struct list_head *list,
							   struct list_head *head)
{
	if (!list_empty(list))
		__list_splice(list, head, head->next);
}

/**
 * list_splice_init - join two lists and reinitialise the emptied list.
 * @list: the new list to add.
 * @head: the place to add it in the first list.
 *
 * The list at @list is reinitialised
 */
static inline void list_splice_init(struct list_head *list,
									struct list_head *head)
{
	if (!list_empty(list))
	{
		__list_splice(list, head, head->next);
		INIT_LIST_HEAD(list);
	}
}

/**
 * list_entry - get the struct for this entry
 * @ptr:	the &struct list_head pointer.
 * @type:	the type of the struct this is embedded in.
 * @member:	the name of the list_struct within the struct.
*/
/* Implementation: subtracts the byte offset of @member within @type
 * (computed from a null @type pointer) from @ptr. */
#define list_entry(ptr, type, member) \
	((type *)((char *)(ptr)-(unsigned long)(&((type *)0)->member)))

/**
 * list_for_each	-	iterate over a list
 * @pos:	the &struct list_head to use as a loop counter.
 * @head:	the head for your list.
 *
 * The next pointer is read from @pos after each iteration, so @pos must
 * still be linked at that point; use list_for_each_safe() when entries are
 * removed inside the loop.
 */
#define list_for_each(pos, head) \
	for (pos = (head)->next; pos != (head); pos = pos->next)

/**
 * list_for_each_prev	-	iterate over a list backwards
 * @pos:	the &struct list_head to use as a loop counter.
 * @head:	the head for your list.
 */
#define list_for_each_prev(pos, head) \
	for (pos = (head)->prev; pos != (head); pos = pos->prev)

/**
 * list_for_each_safe	-	iterate over a list safe against removal of list entry
 * @pos:	the &struct list_head to use as a loop counter.
 * @n:		another &struct list_head to use as temporary storage
 * @head:	the head for your list.
 *
 * @n always holds the entry after @pos, fetched before the loop body runs.
 */
#define list_for_each_safe(pos, n, head) \
	for (pos = (head)->next, n = pos->next; pos != (head); \
		 pos = n, n = pos->next)

/**
 * list_for_each_entry	-	iterate over list of given type
 * @pos:	the type * to use as a loop counter.
 * @head:	the head for your list.
 * @member:	the name of the list_struct within the struct.
 *
 * Uses typeof to recover the containing type from @pos.
 */
#define list_for_each_entry(pos, head, member) \
	for (pos = list_entry((head)->next, typeof (*pos), member); \
		 &pos->member != (head); \
		 pos = list_entry(pos->member.next, typeof (*pos), member))

/*
 * Singly linked list implementation.
*/

/* Link node; embed it in the structure that is to be put on a list. */
struct slist_node {
	struct slist_node *next;
};

/*
 * List head. 'first' is an embedded placeholder node: first.next is the
 * first real entry, or null when the list is empty. 'last' points at the
 * last node, or at 'first' itself when the list is empty.
 */
struct slist_head {
	struct slist_node first, *last;
};

/* Static initializer for an empty list: null first.next, last = &first. */
#define SLIST_HEAD_INIT(name)	{ { (struct slist_node *)0 }, &(name).first }

/* Define and initialize an empty list head called 'name'. */
#define SLIST_HEAD(name) \
	struct slist_head name = SLIST_HEAD_INIT(name)

/* Make @list an empty list. */
static inline void INIT_SLIST_HEAD(struct slist_head *list)
{
	list->first.next = (struct slist_node *)0;
	list->last = &list->first;
}

/* Insert @entry right after @prev; @list->last is updated when @entry
 * becomes the final node. */
static inline void slist_add_after(struct slist_node *entry,
								   struct slist_node *prev,
								   struct slist_head *list)
{
	entry->next = prev->next;
	prev->next = entry;
	if (!entry->next)
		list->last = entry;
}

/* Insert @entry as the first real node of @list. */
static inline void slist_add_head(struct slist_node *entry,
								  struct slist_head *list)
{
	slist_add_after(entry, &list->first, list);
}

/* Append @entry after the current last node of @list. */
static inline void slist_add_tail(struct slist_node *entry,
								  struct slist_head *list)
{
	entry->next = (struct slist_node *)0;
	list->last->next = entry;
	list->last = entry;
}

/* Unlink the node that follows @prev. @prev->next must not be null: it is
 * dereferenced without a check. */
static inline void slist_del_after(struct slist_node *prev,
								   struct slist_head *list)
{
	prev->next = prev->next->next;
	if (!prev->next)
		list->last = prev;
}

/* Unlink the first real node; the list must not be empty. */
static inline void slist_del_head(struct slist_head *list)
{
	slist_del_after(&list->first, list);
}

/* Non-zero when @list has no real node. */
static inline int slist_empty(const struct slist_head *list)
{
	return !list->first.next;
}

/*
 * Insert all nodes of @list after @prev, which belongs to @head.
 * @list must not be empty (the public wrappers check this). The head @list
 * itself is not modified.
 */
static inline void __slist_splice(const struct slist_head *list,
								  struct slist_node *prev,
								  struct slist_head *head)
{
	list->last->next = prev->next;
	prev->next = list->first.next;
	/* Nothing followed @prev, so the spliced tail is the new tail. */
	if (!list->last->next)
		head->last = list->last;
}

/* Splice @list after @prev in @head; no-op when @list is empty. @list is
 * not reinitialised. */
static inline void slist_splice(const struct slist_head *list,
								struct slist_node *prev,
								struct slist_head *head)
{
	if (!slist_empty(list))
		__slist_splice(list, prev, head);
}

/* Same as slist_splice(), then reset @list to empty. */
static inline void slist_splice_init(struct slist_head *list,
									 struct slist_node *prev,
									 struct slist_head *head)
{
	if (!slist_empty(list))
	{
		__slist_splice(list, prev, head);
		INIT_SLIST_HEAD(list);
	}
}

/* Get the structure of type @type that embeds the slist_node @ptr as
 * @member. */
#define slist_entry(ptr, type, member) \
	((type *)((char *)(ptr)-(unsigned long)(&((type *)0)->member)))

/* Iterate over the nodes of @head; @pos is a struct slist_node pointer. */
#define slist_for_each(pos, head) \
	for (pos = (head)->first.next; pos; pos = pos->next)

/*
 * Iterate while allowing the current node to be unlinked in the loop body.
 * @prev trails @pos: after each iteration it advances to @pos only if @pos
 * is still the node linked after it; otherwise it stays where it is.
 */
#define slist_for_each_safe(pos, prev, head) \
	for (prev = &(head)->first, pos = prev->next; pos; \
		 prev = prev->next == pos ? pos : prev, pos = prev->next)

/* Iterate over containing structures; stops when the address of
 * @pos->member is null. */
#define slist_for_each_entry(pos, head, member) \
	for (pos = slist_entry((head)->first.next, typeof (*pos), member); \
		 &pos->member != (struct slist_node *)0; \
		 pos = slist_entry(pos->member.next, typeof (*pos), member))

#endif
workflow-0.11.8/src/kernel/mpoller.c000066400000000000000000000045011476003635400173500ustar00rootroot00000000000000/*
  Copyright (c) 2019 Sogou, Inc.

  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

  Unless required by applicable law or agreed to in writing, software
  distributed under the License is distributed on an "AS IS" BASIS,
  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  See the License for the specific language governing permissions and
  limitations under the License.
Author: Xie Han (xiehan@sogou-inc.com) */

#include
#include
#include "poller.h"
#include "mpoller.h"

/* Defined outside this file; they create/destroy one poller that uses a
 * caller-supplied node array. */
extern poller_t *__poller_create(void **, const struct poller_params *);
extern void __poller_destroy(poller_t *);

/*
 * Create mpoller->nthreads pollers that all share one zero-initialised
 * array of params->max_open_files pointers.
 *
 * Returns 0 on success. On failure every poller created so far is
 * destroyed, the array is freed and -1 is returned.
 */
static int __mpoller_create(const struct poller_params *params,
							mpoller_t *mpoller)
{
	void **nodes_buf = (void **)calloc(params->max_open_files, sizeof (void *));
	unsigned int i;

	if (nodes_buf)
	{
		for (i = 0; i < mpoller->nthreads; i++)
		{
			mpoller->poller[i] = __poller_create(nodes_buf, params);
			if (!mpoller->poller[i])
				break;
		}

		if (i == mpoller->nthreads)
		{
			mpoller->nodes_buf = nodes_buf;
			return 0;
		}

		/* Roll back in reverse order of creation. */
		while (i > 0)
			__poller_destroy(mpoller->poller[--i]);

		free(nodes_buf);
	}

	return -1;
}

/*
 * Allocate a multi-poller with 'nthreads' pollers (0 is treated as 1).
 * Returns the new object, or NULL on failure.
 */
mpoller_t *mpoller_create(const struct poller_params *params, size_t nthreads)
{
	mpoller_t *mpoller;
	size_t size;

	if (nthreads == 0)
		nthreads = 1;

	/* struct __mpoller declares poller[1]; allocate room for nthreads
	 * entries starting at that member. */
	size = offsetof(mpoller_t, poller) + nthreads * sizeof (void *);
	mpoller = (mpoller_t *)malloc(size);
	if (mpoller)
	{
		mpoller->nthreads = (unsigned int)nthreads;
		if (__mpoller_create(params, mpoller) >= 0)
			return mpoller;

		free(mpoller);
	}

	return NULL;
}

/*
 * Start every poller. If one fails to start, the pollers already started
 * are stopped again and -1 is returned; returns 0 when all started.
 */
int mpoller_start(mpoller_t *mpoller)
{
	size_t i;

	for (i = 0; i < mpoller->nthreads; i++)
	{
		if (poller_start(mpoller->poller[i]) < 0)
			break;
	}

	if (i == mpoller->nthreads)
		return 0;

	while (i > 0)
		poller_stop(mpoller->poller[--i]);

	return -1;
}

/* Stop every poller, in index order. */
void mpoller_stop(mpoller_t *mpoller)
{
	size_t i;

	for (i = 0; i < mpoller->nthreads; i++)
		poller_stop(mpoller->poller[i]);
}

/* Destroy every poller, then free the shared node array and the
 * multi-poller itself. */
void mpoller_destroy(mpoller_t *mpoller)
{
	size_t i;

	for (i = 0; i < mpoller->nthreads; i++)
		__poller_destroy(mpoller->poller[i]);

	free(mpoller->nodes_buf);
	free(mpoller);
}
workflow-0.11.8/src/kernel/mpoller.h000066400000000000000000000044571476003635400173670ustar00rootroot00000000000000/*
  Copyright (c) 2019 Sogou, Inc.

  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

  Unless required by applicable law or agreed to in writing, software
  distributed under the License is distributed on an "AS IS" BASIS,
  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  See the License for the specific language governing permissions and
  limitations under the License.

  Author: Xie Han (xiehan@sogou-inc.com)
*/

#ifndef _MPOLLER_H_
#define _MPOLLER_H_

#include
#include "poller.h"

typedef struct __mpoller mpoller_t;

#ifdef __cplusplus
extern "C"
{
#endif

mpoller_t *mpoller_create(const struct poller_params *params, size_t nthreads);
int mpoller_start(mpoller_t *mpoller);
void mpoller_stop(mpoller_t *mpoller);
void mpoller_destroy(mpoller_t *mpoller);

#ifdef __cplusplus
}
#endif

/*
 * A group of pollers sharing one node array. The structure is allocated
 * with room for 'nthreads' entries in 'poller' (see mpoller_create()).
 */
struct __mpoller
{
	void **nodes_buf;
	unsigned int nthreads;
	poller_t *poller[1];
};

/* The fd-based wrappers below forward to the poller selected by
 * (unsigned int)fd % nthreads, so a given fd always maps to the same
 * poller. */

static inline int mpoller_add(const struct poller_data *data, int timeout,
							  mpoller_t *mpoller)
{
	int index = (unsigned int)data->fd % mpoller->nthreads;
	return poller_add(data, timeout, mpoller->poller[index]);
}

static inline int mpoller_del(int fd, mpoller_t *mpoller)
{
	int index = (unsigned int)fd % mpoller->nthreads;
	return poller_del(fd, mpoller->poller[index]);
}

static inline int mpoller_mod(const struct poller_data *data, int timeout,
							  mpoller_t *mpoller)
{
	int index = (unsigned int)data->fd % mpoller->nthreads;
	return poller_mod(data, timeout, mpoller->poller[index]);
}

static inline int mpoller_set_timeout(int fd, int timeout, mpoller_t *mpoller)
{
	int index = (unsigned int)fd % mpoller->nthreads;
	return poller_set_timeout(fd, timeout, mpoller->poller[index]);
}

/*
 * Add a timer on a poller picked in rotation. The chosen poller index is
 * stored in *index; pass it back to mpoller_del_timer().
 * NOTE(review): the rotation counter is a function-local static that is
 * incremented without synchronisation, and the function is static inline
 * in a header -- confirm that is intended.
 */
static inline int mpoller_add_timer(const struct timespec *value, void *context,
									void **timer, int *index,
									mpoller_t *mpoller)
{
	static unsigned int n = 0;
	*index = n++ % mpoller->nthreads;
	return poller_add_timer(value, context, timer, mpoller->poller[*index]);
}

/* Delete a timer from the poller at 'index'; 'index' is used as given,
 * without a range check. */
static inline int mpoller_del_timer(void *timer, int index,
									mpoller_t *mpoller)
{
	return poller_del_timer(timer, mpoller->poller[index]);
}

#endif
workflow-0.11.8/src/kernel/msgqueue.c000066400000000000000000000114711476003635400175350ustar00rootroot00000000000000/*
  Copyright (c) 2020 Sogou, Inc.

  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

  Unless required by applicable law or agreed to in writing, software
  distributed under the License is distributed on an "AS IS" BASIS,
  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  See the License for the specific language governing permissions and
  limitations under the License.

  Author: Xie Han (xiehan@sogou-inc.com)
*/

/*
 * This message queue originates from the project of Sogou C++ Workflow:
 * https://github.com/sogou/workflow
 *
 * The idea of this implementation is quite simple and obvious. When the
 * get_list is not empty, the consumer takes a message. Otherwise the consumer
 * waits till put_list is not empty, and swap two lists. This method performs
 * well when the queue is very busy, and the number of consumers is big.
*/

#include
#include
#include
#include "msgqueue.h"

/*
 * Two singly linked lists of messages: a "get" list drained by consumers
 * and a "put" list filled by producers. Each list is chained through the
 * pointer-sized link slot located 'linkoff' bytes into every message.
 */
struct __msgqueue
{
	size_t msg_max;		/* producers block while msg_cnt > msg_max - 1 */
	size_t msg_cnt;		/* number of messages on the put list */
	int linkoff;		/* offset of the link slot inside a message */
	int nonblock;		/* non-zero: put and get never wait */
	void *head1;		/* storage for the two list heads; which one */
	void *head2;		/* is "get" and which is "put" swaps over time */
	void **get_head;
	void **put_head;
	void **put_tail;	/* link slot to write when appending to put list */
	pthread_mutex_t get_mutex;
	pthread_mutex_t put_mutex;
	pthread_cond_t get_cond;
	pthread_cond_t put_cond;
};

/*
 * Switch to non-blocking mode and wake waiters: one signal on get_cond and
 * a broadcast on put_cond, both issued while holding put_mutex. The flag
 * itself is written before the mutex is taken.
 */
void msgqueue_set_nonblock(msgqueue_t *queue)
{
	queue->nonblock = 1;
	pthread_mutex_lock(&queue->put_mutex);
	pthread_cond_signal(&queue->get_cond);
	pthread_cond_broadcast(&queue->put_cond);
	pthread_mutex_unlock(&queue->put_mutex);
}

/* Switch back to blocking mode; no lock is taken and nobody is woken. */
void msgqueue_set_block(msgqueue_t *queue)
{
	queue->nonblock = 0;
}

/*
 * Append 'msg' to the put list and signal one consumer.
 * In blocking mode the caller waits on put_cond while
 * msg_cnt > msg_max - 1. The subtraction is done in size_t, so with
 * msg_max == 0 the right-hand side wraps and the wait never triggers.
 */
void msgqueue_put(void *msg, msgqueue_t *queue)
{
	void **link = (void **)((char *)msg + queue->linkoff);

	*link = NULL;
	pthread_mutex_lock(&queue->put_mutex);
	while (queue->msg_cnt > queue->msg_max - 1 && !queue->nonblock)
		pthread_cond_wait(&queue->put_cond, &queue->put_mutex);

	*queue->put_tail = link;
	queue->put_tail = link;
	queue->msg_cnt++;
	pthread_mutex_unlock(&queue->put_mutex);
	pthread_cond_signal(&queue->get_cond);
}

/*
 * Insert 'msg' so that it is the next one returned by msgqueue_get().
 *
 * If the get list is non-empty, the message goes to the front of the get
 * list; get_mutex is acquired with trylock in a loop while put_mutex is
 * held, and the loop is left as soon as the get list is seen empty.
 * Otherwise the message goes to the front of the put list, subject to the
 * same capacity wait as msgqueue_put().
 */
void msgqueue_put_head(void *msg, msgqueue_t *queue)
{
	void **link = (void **)((char *)msg + queue->linkoff);

	pthread_mutex_lock(&queue->put_mutex);
	while (*queue->get_head)
	{
		if (pthread_mutex_trylock(&queue->get_mutex) == 0)
		{
			pthread_mutex_unlock(&queue->put_mutex);
			/* msg_cnt is not changed here: it counts the put list. */
			*link = *queue->get_head;
			*queue->get_head = link;
			pthread_mutex_unlock(&queue->get_mutex);
			return;
		}
	}

	while (queue->msg_cnt > queue->msg_max - 1 && !queue->nonblock)
		pthread_cond_wait(&queue->put_cond, &queue->put_mutex);

	*link = *queue->put_head;
	/* Put list was empty: the new message is also its tail. */
	if (*link == NULL)
		queue->put_tail = link;

	*queue->put_head = link;
	queue->msg_cnt++;
	pthread_mutex_unlock(&queue->put_mutex);
	pthread_cond_signal(&queue->get_cond);
}

/*
 * Exchange the (empty) get list with the put list and return how many
 * messages were taken over. Called from msgqueue_get() with get_mutex
 * held. In blocking mode waits on get_cond until the put list is
 * non-empty; in non-blocking mode may return 0.
 */
static size_t __msgqueue_swap(msgqueue_t *queue)
{
	void **get_head = queue->get_head;
	size_t cnt;

	pthread_mutex_lock(&queue->put_mutex);
	while (queue->msg_cnt == 0 && !queue->nonblock)
		pthread_cond_wait(&queue->get_cond, &queue->put_mutex);

	cnt = queue->msg_cnt;
	/* Same threshold producers wait on: wake them all. */
	if (cnt > queue->msg_max - 1)
		pthread_cond_broadcast(&queue->put_cond);

	queue->get_head = queue->put_head;
	queue->put_head = get_head;
	queue->put_tail = get_head;
	queue->msg_cnt = 0;
	pthread_mutex_unlock(&queue->put_mutex);
	return cnt;
}

/*
 * Take the next message. Returns NULL only when both lists are empty and
 * the swap returned without waiting (non-blocking mode).
 */
void *msgqueue_get(msgqueue_t *queue)
{
	void *msg;

	pthread_mutex_lock(&queue->get_mutex);
	if (*queue->get_head || __msgqueue_swap(queue) > 0)
	{
		/* The list stores link-slot addresses; step back by linkoff
		 * to recover the message, then advance the head. */
		msg = (char *)*queue->get_head - queue->linkoff;
		*queue->get_head = *(void **)*queue->get_head;
	}
	else
		msg = NULL;

	pthread_mutex_unlock(&queue->get_mutex);
	return msg;
}

/*
 * Allocate and initialise a queue in blocking mode.
 * Returns NULL on failure; when a pthread initialiser fails, the objects
 * already initialised are destroyed in reverse order and errno is set to
 * the error code that initialiser returned.
 */
msgqueue_t *msgqueue_create(size_t maxlen, int linkoff)
{
	msgqueue_t *queue = (msgqueue_t *)malloc(sizeof (msgqueue_t));
	int ret;

	if (!queue)
		return NULL;

	ret = pthread_mutex_init(&queue->get_mutex, NULL);
	if (ret == 0)
	{
		ret = pthread_mutex_init(&queue->put_mutex, NULL);
		if (ret == 0)
		{
			ret = pthread_cond_init(&queue->get_cond, NULL);
			if (ret == 0)
			{
				ret = pthread_cond_init(&queue->put_cond, NULL);
				if (ret == 0)
				{
					queue->msg_max = maxlen;
					queue->linkoff = linkoff;
					queue->head1 = NULL;
					queue->head2 = NULL;
					queue->get_head = &queue->head1;
					queue->put_head = &queue->head2;
					queue->put_tail = &queue->head2;
					queue->msg_cnt = 0;
					queue->nonblock = 0;
					return queue;
				}

				pthread_cond_destroy(&queue->get_cond);
			}

			pthread_mutex_destroy(&queue->put_mutex);
		}

		pthread_mutex_destroy(&queue->get_mutex);
	}

	errno = ret;
	free(queue);
	return NULL;
}

/* Destroy the synchronisation objects and free the queue. Messages still
 * linked in the queue are not touched. */
void msgqueue_destroy(msgqueue_t *queue)
{
	pthread_cond_destroy(&queue->put_cond);
	pthread_cond_destroy(&queue->get_cond);
	pthread_mutex_destroy(&queue->put_mutex);
	pthread_mutex_destroy(&queue->get_mutex);
	free(queue);
}
workflow-0.11.8/src/kernel/msgqueue.h000066400000000000000000000027321476003635400175420ustar00rootroot00000000000000/*
  Copyright (c) 2020 Sogou, Inc.

  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

  Unless required by applicable law or agreed to in writing, software
  distributed under the License is distributed on an "AS IS" BASIS,
  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  See the License for the specific language governing permissions and
  limitations under the License.

  Author: Xie Han (xiehan@sogou-inc.com)
*/

#ifndef _MSGQUEUE_H_
#define _MSGQUEUE_H_

#include

typedef struct __msgqueue msgqueue_t;

#ifdef __cplusplus
extern "C"
{
#endif

/* A simple implementation of message queue. The max pending messages may
 * reach two times 'maxlen' when the queue is in blocking mode, and infinite
 * in nonblocking mode. 'linkoff' is the offset from the head of each message,
 * where spaces of one pointer size should be available for internal usage.
 * 'linkoff' can be positive or negative or zero. */

msgqueue_t *msgqueue_create(size_t maxlen, int linkoff);
void *msgqueue_get(msgqueue_t *queue);
void msgqueue_put(void *msg, msgqueue_t *queue);
void msgqueue_put_head(void *msg, msgqueue_t *queue);
void msgqueue_set_nonblock(msgqueue_t *queue);
void msgqueue_set_block(msgqueue_t *queue);
void msgqueue_destroy(msgqueue_t *queue);

#ifdef __cplusplus
}
#endif

#endif
workflow-0.11.8/src/kernel/poller.c000066400000000000000000001016511476003635400171770ustar00rootroot00000000000000/*
  Copyright (c) 2019 Sogou, Inc.

  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

  Unless required by applicable law or agreed to in writing, software
  distributed under the License is distributed on an "AS IS" BASIS,
  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  See the License for the specific language governing permissions and
  limitations under the License.
Author: Xie Han (xiehan@sogou-inc.com) */ #include #include #include #ifdef __linux__ # include # include #else # include # undef LIST_HEAD # undef SLIST_HEAD #endif #include #include #include #include #include #include #include #include #include "list.h" #include "rbtree.h" #include "poller.h" #define POLLER_BUFSIZE (256 * 1024) #define POLLER_EVENTS_MAX 256 struct __poller_node { int state; int error; struct poller_data data; #pragma pack(1) union { struct list_head list; struct rb_node rb; }; #pragma pack() char in_rbtree; char removed; int event; struct timespec timeout; struct __poller_node *res; }; struct __poller { size_t max_open_files; void (*callback)(struct poller_result *, void *); void *context; pthread_t tid; int pfd; int timerfd; int pipe_rd; int pipe_wr; int stopped; struct rb_root timeo_tree; struct rb_node *tree_first; struct rb_node *tree_last; struct list_head timeo_list; struct list_head no_timeo_list; struct __poller_node **nodes; pthread_mutex_t mutex; char buf[POLLER_BUFSIZE]; }; #ifdef __linux__ static inline int __poller_create_pfd() { return epoll_create(1); } static inline int __poller_close_pfd(int fd) { return close(fd); } static inline int __poller_add_fd(int fd, int event, void *data, poller_t *poller) { struct epoll_event ev = { .events = event, .data = { .ptr = data } }; return epoll_ctl(poller->pfd, EPOLL_CTL_ADD, fd, &ev); } static inline int __poller_del_fd(int fd, int event, poller_t *poller) { return epoll_ctl(poller->pfd, EPOLL_CTL_DEL, fd, NULL); } static inline int __poller_mod_fd(int fd, int old_event, int new_event, void *data, poller_t *poller) { struct epoll_event ev = { .events = new_event, .data = { .ptr = data } }; return epoll_ctl(poller->pfd, EPOLL_CTL_MOD, fd, &ev); } static inline int __poller_create_timerfd() { return timerfd_create(CLOCK_MONOTONIC, 0); } static inline int __poller_close_timerfd(int fd) { return close(fd); } static inline int __poller_add_timerfd(int fd, poller_t *poller) { struct epoll_event ev 
= { .events = EPOLLIN | EPOLLET, .data = { .ptr = NULL } }; return epoll_ctl(poller->pfd, EPOLL_CTL_ADD, fd, &ev); } static inline int __poller_set_timerfd(int fd, const struct timespec *abstime, poller_t *poller) { struct itimerspec timer = { .it_interval = { }, .it_value = *abstime }; return timerfd_settime(fd, TFD_TIMER_ABSTIME, &timer, NULL); } typedef struct epoll_event __poller_event_t; static inline int __poller_wait(__poller_event_t *events, int maxevents, poller_t *poller) { return epoll_wait(poller->pfd, events, maxevents, -1); } static inline void *__poller_event_data(const __poller_event_t *event) { return event->data.ptr; } #else /* BSD, macOS */ static inline int __poller_create_pfd() { return kqueue(); } static inline int __poller_close_pfd(int fd) { return close(fd); } static inline int __poller_add_fd(int fd, int event, void *data, poller_t *poller) { struct kevent ev; EV_SET(&ev, fd, event, EV_ADD, 0, 0, data); return kevent(poller->pfd, &ev, 1, NULL, 0, NULL); } static inline int __poller_del_fd(int fd, int event, poller_t *poller) { struct kevent ev; EV_SET(&ev, fd, event, EV_DELETE, 0, 0, NULL); return kevent(poller->pfd, &ev, 1, NULL, 0, NULL); } static inline int __poller_mod_fd(int fd, int old_event, int new_event, void *data, poller_t *poller) { struct kevent ev[2]; EV_SET(&ev[0], fd, old_event, EV_DELETE, 0, 0, NULL); EV_SET(&ev[1], fd, new_event, EV_ADD, 0, 0, data); return kevent(poller->pfd, ev, 2, NULL, 0, NULL); } static inline int __poller_create_timerfd() { return 0; } static inline int __poller_close_timerfd(int fd) { return 0; } static inline int __poller_add_timerfd(int fd, poller_t *poller) { return 0; } static int __poller_set_timerfd(int fd, const struct timespec *abstime, poller_t *poller) { struct timespec curtime; long long nseconds; struct kevent ev; int flags; if (abstime->tv_sec || abstime->tv_nsec) { clock_gettime(CLOCK_MONOTONIC, &curtime); nseconds = 1000000000LL * (abstime->tv_sec - curtime.tv_sec); nseconds += 
abstime->tv_nsec - curtime.tv_nsec; flags = EV_ADD; } else { nseconds = 0; flags = EV_DELETE; } EV_SET(&ev, fd, EVFILT_TIMER, flags, NOTE_NSECONDS, nseconds, NULL); return kevent(poller->pfd, &ev, 1, NULL, 0, NULL); } typedef struct kevent __poller_event_t; static inline int __poller_wait(__poller_event_t *events, int maxevents, poller_t *poller) { return kevent(poller->pfd, NULL, 0, events, maxevents, NULL); } static inline void *__poller_event_data(const __poller_event_t *event) { return event->udata; } #define EPOLLIN EVFILT_READ #define EPOLLOUT EVFILT_WRITE #define EPOLLET 0 #endif static inline long __timeout_cmp(const struct __poller_node *node1, const struct __poller_node *node2) { long ret = node1->timeout.tv_sec - node2->timeout.tv_sec; if (ret == 0) ret = node1->timeout.tv_nsec - node2->timeout.tv_nsec; return ret; } static void __poller_tree_insert(struct __poller_node *node, poller_t *poller) { struct rb_node **p = &poller->timeo_tree.rb_node; struct rb_node *parent = NULL; struct __poller_node *entry; entry = rb_entry(poller->tree_last, struct __poller_node, rb); if (!*p) { poller->tree_first = &node->rb; poller->tree_last = &node->rb; } else if (__timeout_cmp(node, entry) >= 0) { parent = poller->tree_last; p = &parent->rb_right; poller->tree_last = &node->rb; } else { do { parent = *p; entry = rb_entry(*p, struct __poller_node, rb); if (__timeout_cmp(node, entry) < 0) p = &(*p)->rb_left; else p = &(*p)->rb_right; } while (*p); if (p == &poller->tree_first->rb_left) poller->tree_first = &node->rb; } node->in_rbtree = 1; rb_link_node(&node->rb, parent, p); rb_insert_color(&node->rb, &poller->timeo_tree); } static inline void __poller_tree_erase(struct __poller_node *node, poller_t *poller) { if (&node->rb == poller->tree_first) poller->tree_first = rb_next(&node->rb); if (&node->rb == poller->tree_last) poller->tree_last = rb_prev(&node->rb); rb_erase(&node->rb, &poller->timeo_tree); node->in_rbtree = 0; } static int __poller_remove_node(struct 
__poller_node *node, poller_t *poller) { int removed; pthread_mutex_lock(&poller->mutex); removed = node->removed; if (!removed) { poller->nodes[node->data.fd] = NULL; if (node->in_rbtree) __poller_tree_erase(node, poller); else list_del(&node->list); __poller_del_fd(node->data.fd, node->event, poller); } pthread_mutex_unlock(&poller->mutex); return removed; } static int __poller_append_message(const void *buf, size_t *n, struct __poller_node *node, poller_t *poller) { poller_message_t *msg = node->data.message; struct __poller_node *res; int ret; if (!msg) { res = (struct __poller_node *)malloc(sizeof (struct __poller_node)); if (!res) return -1; msg = node->data.create_message(node->data.context); if (!msg) { free(res); return -1; } node->data.message = msg; node->res = res; } else res = node->res; ret = msg->append(buf, n, msg); if (ret > 0) { res->data = node->data; res->error = 0; res->state = PR_ST_SUCCESS; poller->callback((struct poller_result *)res, poller->context); node->data.message = NULL; node->res = NULL; } return ret; } static int __poller_handle_ssl_error(struct __poller_node *node, int ret, poller_t *poller) { int error = SSL_get_error(node->data.ssl, ret); int event; switch (error) { case SSL_ERROR_WANT_READ: event = EPOLLIN | EPOLLET; break; case SSL_ERROR_WANT_WRITE: event = EPOLLOUT | EPOLLET; break; default: errno = -error; case SSL_ERROR_SYSCALL: return -1; } if (event == node->event) return 0; pthread_mutex_lock(&poller->mutex); if (!node->removed) { ret = __poller_mod_fd(node->data.fd, node->event, event, node, poller); if (ret >= 0) node->event = event; } else ret = 0; pthread_mutex_unlock(&poller->mutex); return ret; } static void __poller_handle_read(struct __poller_node *node, poller_t *poller) { ssize_t nleft; size_t n; char *p; while (1) { p = poller->buf; if (!node->data.ssl) { nleft = read(node->data.fd, p, POLLER_BUFSIZE); if (nleft < 0) { if (errno == EAGAIN) return; } } else { nleft = SSL_read(node->data.ssl, p, POLLER_BUFSIZE); 
if (nleft < 0)
			{
				/* >= 0 means the SSL layer just needs another event. */
				if (__poller_handle_ssl_error(node, nleft, poller) >= 0)
					return;
			}
		}

		/* 0 bytes or an unrecoverable error: leave the loop and finish. */
		if (nleft <= 0)
			break;

		/* n is in/out: bytes offered in, bytes consumed out; loop until the
		 * whole buffer has been consumed or appending fails. */
		do
		{
			n = nleft;
			if (__poller_append_message(p, &n, node, poller) >= 0)
			{
				nleft -= n;
				p += n;
			}
			else
				nleft = -1;
		} while (nleft > 0);

		if (nleft < 0)
			break;

		/* Stop working on a node that has been marked removed meanwhile. */
		if (node->removed)
			return;
	}

	/* Non-zero: somebody else already removed the node; do not report it. */
	if (__poller_remove_node(node, poller))
		return;

	if (nleft == 0)
	{
		node->error = 0;
		node->state = PR_ST_FINISHED;
	}
	else
	{
		node->error = errno;
		node->state = PR_ST_ERROR;
	}

	free(node->res);
	poller->callback((struct poller_result *)node, poller->context);
}

/* Fallback definition of IOV_MAX: UIO_MAXIOV when available, else 1024. */
#ifndef IOV_MAX
# ifdef UIO_MAXIOV
# define IOV_MAX UIO_MAXIOV
# else
# define IOV_MAX 1024
# endif
#endif

/*
 * PD_OP_WRITE handler: writes the node's iovec array with writev(), or one
 * iovec at a time with SSL_write() when data.ssl is set.
 */
static void __poller_handle_write(struct __poller_node *node,
								  poller_t *poller)
{
	struct iovec *iov = node->data.write_iov;
	size_t count = 0;
	ssize_t nleft;
	int iovcnt;
	int ret;

	while (node->data.iovcnt > 0)
	{
		if (!node->data.ssl)
		{
			/* writev() is given at most IOV_MAX vectors per call. */
			iovcnt = node->data.iovcnt;
			if (iovcnt > IOV_MAX)
				iovcnt = IOV_MAX;

			nleft = writev(node->data.fd, iov, iovcnt);
			if (nleft < 0)
			{
				ret = errno == EAGAIN ?
0 : -1;
				break;
			}
		}
		else if (iov->iov_len > 0)
		{
			nleft = SSL_write(node->data.ssl, iov->iov_base, iov->iov_len);
			if (nleft <= 0)
			{
				ret = __poller_handle_ssl_error(node, nleft, poller);
				break;
			}
		}
		else
			nleft = 0;

		count += nleft;
		/* Advance the iovec array past the bytes that were just written. */
		do
		{
			if (nleft >= iov->iov_len)
			{
				nleft -= iov->iov_len;
				iov->iov_base = (char *)iov->iov_base + iov->iov_len;
				iov->iov_len = 0;
				iov++;
				node->data.iovcnt--;
			}
			else
			{
				iov->iov_base = (char *)iov->iov_base + nleft;
				iov->iov_len -= nleft;
				break;
			}
		} while (node->data.iovcnt > 0);
	}

	node->data.write_iov = iov;
	/* iovcnt > 0 here implies the loop was left through a break, and every
	 * break above assigns ret first, so ret is never read uninitialized. */
	if (node->data.iovcnt > 0 && ret >= 0)
	{
		if (count == 0)
			return;

		/* Report progress; a negative return aborts the write. */
		if (node->data.partial_written(count, node->data.context) >= 0)
			return;
	}

	if (__poller_remove_node(node, poller))
		return;

	if (node->data.iovcnt == 0)
	{
		node->error = 0;
		node->state = PR_ST_FINISHED;
	}
	else
	{
		node->error = errno;
		node->state = PR_ST_ERROR;
	}

	poller->callback((struct poller_result *)node, poller->context);
}

/*
 * PD_OP_LISTEN handler: accepts connections in a loop. Each accepted socket
 * is passed to data.accept(); a non-NULL result is delivered through
 * poller->callback() in the pre-allocated res node (PR_ST_SUCCESS) and a new
 * res node is allocated for the next connection.
 * EAGAIN / EMFILE / ENFILE return quietly, ECONNABORTED retries. Any other
 * accept() error, a NULL result or a failed malloc() removes the node and
 * reports PR_ST_ERROR with the current errno.
 */
static void __poller_handle_listen(struct __poller_node *node,
								   poller_t *poller)
{
	struct __poller_node *res = node->res;
	struct sockaddr_storage ss;
	struct sockaddr *addr = (struct sockaddr *)&ss;
	socklen_t addrlen;
	void *result;
	int sockfd;

	while (1)
	{
		addrlen = sizeof (struct sockaddr_storage);
		sockfd = accept(node->data.fd, addr, &addrlen);
		if (sockfd < 0)
		{
			if (errno == EAGAIN || errno == EMFILE || errno == ENFILE)
				return;
			else if (errno == ECONNABORTED)
				continue;
			else
				break;
		}

		result = node->data.accept(addr, addrlen, sockfd, node->data.context);
		if (!result)
			break;

		res->data = node->data;
		res->data.result = result;
		res->error = 0;
		res->state = PR_ST_SUCCESS;
		poller->callback((struct poller_result *)res, poller->context);

		/* The previous res now belongs to the callback; get a fresh one. */
		res = (struct __poller_node *)malloc(sizeof (struct __poller_node));
		node->res = res;
		if (!res)
			break;

		if (node->removed)
			return;
	}

	if (__poller_remove_node(node, poller))
		return;

	node->error = errno;
	node->state = PR_ST_ERROR;
	free(node->res);
	poller->callback((struct poller_result *)node, poller->context);
}

static void
__poller_handle_connect(struct __poller_node *node,
									poller_t *poller)
{
	/* PD_OP_CONNECT handler: the outcome of the connect is read from the
	 * socket's SO_ERROR option and reported as FINISHED or ERROR. */
	socklen_t len = sizeof (int);
	int error;

	if (getsockopt(node->data.fd, SOL_SOCKET, SO_ERROR, &error, &len) < 0)
		error = errno;

	if (__poller_remove_node(node, poller))
		return;

	if (error == 0)
	{
		node->error = 0;
		node->state = PR_ST_FINISHED;
	}
	else
	{
		node->error = error;
		node->state = PR_ST_ERROR;
	}

	poller->callback((struct poller_result *)node, poller->context);
}

/*
 * PD_OP_RECVFROM handler: receives datagrams into poller->buf in a loop and
 * passes each one to data.recvfrom(). A non-NULL result is delivered through
 * poller->callback() in the pre-allocated res node (PR_ST_SUCCESS) and a new
 * res node is allocated. EAGAIN returns quietly; any other recvfrom() error,
 * a NULL result or a failed malloc() removes the node and reports
 * PR_ST_ERROR with the current errno.
 */
static void __poller_handle_recvfrom(struct __poller_node *node,
									 poller_t *poller)
{
	struct __poller_node *res = node->res;
	struct sockaddr_storage ss;
	struct sockaddr *addr = (struct sockaddr *)&ss;
	socklen_t addrlen;
	void *result;
	ssize_t n;

	while (1)
	{
		addrlen = sizeof (struct sockaddr_storage);
		n = recvfrom(node->data.fd, poller->buf, POLLER_BUFSIZE, 0,
					 addr, &addrlen);
		if (n < 0)
		{
			if (errno == EAGAIN)
				return;
			else
				break;
		}

		result = node->data.recvfrom(addr, addrlen, poller->buf, n,
									 node->data.context);
		if (!result)
			break;

		res->data = node->data;
		res->data.result = result;
		res->error = 0;
		res->state = PR_ST_SUCCESS;
		poller->callback((struct poller_result *)res, poller->context);

		res = (struct __poller_node *)malloc(sizeof (struct __poller_node));
		node->res = res;
		if (!res)
			break;

		if (node->removed)
			return;
	}

	if (__poller_remove_node(node, poller))
		return;

	node->error = errno;
	node->state = PR_ST_ERROR;
	free(node->res);
	poller->callback((struct poller_result *)node, poller->context);
}

/*
 * PD_OP_SSL_ACCEPT handler: runs SSL_accept(). If it did not succeed and
 * __poller_handle_ssl_error() returns >= 0, the handler returns and waits for
 * the next event; otherwise the node is removed and reported as FINISHED
 * (ret > 0) or ERROR with errno.
 */
static void __poller_handle_ssl_accept(struct __poller_node *node,
									   poller_t *poller)
{
	int ret = SSL_accept(node->data.ssl);

	if (ret <= 0)
	{
		if (__poller_handle_ssl_error(node, ret, poller) >= 0)
			return;
	}

	if (__poller_remove_node(node, poller))
		return;

	if (ret > 0)
	{
		node->error = 0;
		node->state = PR_ST_FINISHED;
	}
	else
	{
		node->error = errno;
		node->state = PR_ST_ERROR;
	}

	poller->callback((struct poller_result *)node, poller->context);
}

/* PD_OP_SSL_CONNECT handler; same flow as the SSL accept handler above. */
static void __poller_handle_ssl_connect(struct __poller_node *node,
										poller_t *poller)
{
	int ret =
SSL_connect(node->data.ssl);

	if (ret <= 0)
	{
		/* >= 0 means the handshake just needs another event. */
		if (__poller_handle_ssl_error(node, ret, poller) >= 0)
			return;
	}

	if (__poller_remove_node(node, poller))
		return;

	if (ret > 0)
	{
		node->error = 0;
		node->state = PR_ST_FINISHED;
	}
	else
	{
		node->error = errno;
		node->state = PR_ST_ERROR;
	}

	poller->callback((struct poller_result *)node, poller->context);
}

/*
 * PD_OP_SSL_SHUTDOWN handler: runs SSL_shutdown(). If it returned <= 0 and
 * __poller_handle_ssl_error() returns >= 0, the handler waits for the next
 * event; otherwise the node is removed and reported as FINISHED (ret > 0) or
 * ERROR with errno.
 */
static void __poller_handle_ssl_shutdown(struct __poller_node *node,
										 poller_t *poller)
{
	int ret = SSL_shutdown(node->data.ssl);

	if (ret <= 0)
	{
		if (__poller_handle_ssl_error(node, ret, poller) >= 0)
			return;
	}

	if (__poller_remove_node(node, poller))
		return;

	if (ret > 0)
	{
		node->error = 0;
		node->state = PR_ST_FINISHED;
	}
	else
	{
		node->error = errno;
		node->state = PR_ST_ERROR;
	}

	poller->callback((struct poller_result *)node, poller->context);
}

/*
 * PD_OP_EVENT handler: sums the 8-byte counters read from the fd, then calls
 * data.event() once per counted event, delivering each non-NULL result
 * through poller->callback() (PR_ST_SUCCESS) in a freshly allocated res node.
 * NOTE(review): the 8-byte counter reads look like an eventfd - confirm
 * against the callers.
 */
static void __poller_handle_event(struct __poller_node *node,
								  poller_t *poller)
{
	struct __poller_node *res = node->res;
	unsigned long long cnt = 0;
	unsigned long long value;
	void *result;
	ssize_t n;

	/* Drain the fd; a short (non-negative) read is treated as EINVAL. */
	while (1)
	{
		n = read(node->data.fd, &value, sizeof (unsigned long long));
		if (n == sizeof (unsigned long long))
			cnt += value;
		else
		{
			if (n >= 0)
				errno = EINVAL;
			break;
		}
	}

	/* EAGAIN is the normal way out of the drain loop above. */
	if (errno == EAGAIN)
	{
		while (1)
		{
			if (cnt == 0)
				return;

			cnt--;
			result = node->data.event(node->data.context);
			if (!result)
				break;

			res->data = node->data;
			res->data.result = result;
			res->error = 0;
			res->state = PR_ST_SUCCESS;
			poller->callback((struct poller_result *)res, poller->context);

			res = (struct __poller_node *)malloc(sizeof (struct __poller_node));
			node->res = res;
			if (!res)
				break;

			if (node->removed)
				return;
		}
	}

	/* Events counted but not dispatched are written back to the fd. */
	if (cnt != 0)
		write(node->data.fd, &cnt, sizeof (unsigned long long));

	if (__poller_remove_node(node, poller))
		return;

	node->error = errno;
	node->state = PR_ST_ERROR;
	free(node->res);
	poller->callback((struct poller_result *)node, poller->context);
}

/*
 * PD_OP_NOTIFY handler: reads pointer-sized values from the fd and passes
 * each one to data.notify().
 */
static void __poller_handle_notify(struct __poller_node *node,
								   poller_t *poller)
{
	struct __poller_node *res = node->res;
	void *result;
	ssize_t n;

	while (1)
	{
		n
= read(node->data.fd, &result, sizeof (void *)); if (n == sizeof (void *)) { result = node->data.notify(result, node->data.context); if (!result) break; res->data = node->data; res->data.result = result; res->error = 0; res->state = PR_ST_SUCCESS; poller->callback((struct poller_result *)res, poller->context); res = (struct __poller_node *)malloc(sizeof (struct __poller_node)); node->res = res; if (!res) break; if (node->removed) return; } else if (n < 0 && errno == EAGAIN) return; else { if (n > 0) errno = EINVAL; break; } } if (__poller_remove_node(node, poller)) return; if (n == 0) { node->error = 0; node->state = PR_ST_FINISHED; } else { node->error = errno; node->state = PR_ST_ERROR; } free(node->res); poller->callback((struct poller_result *)node, poller->context); } static int __poller_handle_pipe(poller_t *poller) { struct __poller_node **node = (struct __poller_node **)poller->buf; int stop = 0; int n; int i; n = read(poller->pipe_rd, node, POLLER_BUFSIZE) / sizeof (void *); for (i = 0; i < n; i++) { if (node[i]) { free(node[i]->res); poller->callback((struct poller_result *)node[i], poller->context); } else stop = 1; } return stop; } static void __poller_handle_timeout(const struct __poller_node *time_node, poller_t *poller) { struct __poller_node *node; struct list_head *pos, *tmp; LIST_HEAD(timeo_list); pthread_mutex_lock(&poller->mutex); list_for_each_safe(pos, tmp, &poller->timeo_list) { node = list_entry(pos, struct __poller_node, list); if (__timeout_cmp(node, time_node) > 0) break; if (node->data.fd >= 0) { poller->nodes[node->data.fd] = NULL; __poller_del_fd(node->data.fd, node->event, poller); } else node->removed = 1; list_move_tail(pos, &timeo_list); } while (poller->tree_first) { node = rb_entry(poller->tree_first, struct __poller_node, rb); if (__timeout_cmp(node, time_node) > 0) break; if (node->data.fd >= 0) { poller->nodes[node->data.fd] = NULL; __poller_del_fd(node->data.fd, node->event, poller); } else node->removed = 1; 
poller->tree_first = rb_next(poller->tree_first);
		rb_erase(&node->rb, &poller->timeo_tree);
		list_add_tail(&node->list, &timeo_list);
		/* The tree became empty: there is no last node either. */
		if (!poller->tree_first)
			poller->tree_last = NULL;
	}

	pthread_mutex_unlock(&poller->mutex);
	/* Callbacks are invoked after the mutex has been released. */
	list_for_each_safe(pos, tmp, &timeo_list)
	{
		node = list_entry(pos, struct __poller_node, list);
		/* fd nodes time out with an error; fd-less (timer) nodes finish. */
		if (node->data.fd >= 0)
		{
			node->error = ETIMEDOUT;
			node->state = PR_ST_ERROR;
		}
		else
		{
			node->error = 0;
			node->state = PR_ST_FINISHED;
		}

		free(node->res);
		poller->callback((struct poller_result *)node, poller->context);
	}
}

/*
 * Program the timer fd with the earliest pending timeout, i.e. the smaller of
 * the head of timeo_list and the first node of the timeout rbtree. When no
 * timeout is pending a zeroed timespec is passed to __poller_set_timerfd().
 */
static void __poller_set_timer(poller_t *poller)
{
	struct __poller_node *node = NULL;
	struct __poller_node *first;
	struct timespec abstime;

	pthread_mutex_lock(&poller->mutex);
	if (!list_empty(&poller->timeo_list))
		node = list_entry(poller->timeo_list.next, struct __poller_node, list);

	if (poller->tree_first)
	{
		first = rb_entry(poller->tree_first, struct __poller_node, rb);
		if (!node || __timeout_cmp(first, node) < 0)
			node = first;
	}

	if (node)
		abstime = node->timeout;
	else
	{
		abstime.tv_sec = 0;
		abstime.tv_nsec = 0;
	}

	__poller_set_timerfd(poller->timerfd, &abstime, poller);
	pthread_mutex_unlock(&poller->mutex);
}

/*
 * Poller thread main loop: arm the timer, wait for events, dispatch each
 * event to the handler of its node's operation, then process the control
 * pipe and the expired timeouts. The loop ends when __poller_handle_pipe()
 * reports a stop request.
 */
static void *__poller_thread_routine(void *arg)
{
	poller_t *poller = (poller_t *)arg;
	__poller_event_t events[POLLER_EVENTS_MAX];
	struct __poller_node time_node;
	struct __poller_node *node;
	int has_pipe_event;
	int nevents;
	int i;

	while (1)
	{
		__poller_set_timer(poller);
		nevents = __poller_wait(events, POLLER_EVENTS_MAX, poller);
		/* Timeouts are compared against the time taken right after waking. */
		clock_gettime(CLOCK_MONOTONIC, &time_node.timeout);
		has_pipe_event = 0;
		for (i = 0; i < nevents; i++)
		{
			node = (struct __poller_node *)__poller_event_data(&events[i]);
			/* Values <= 1 are not real nodes: 1 is the control pipe (see
			 * __poller_open_pipe); NULL is presumably the timer fd. */
			if (node <= (struct __poller_node *)1)
			{
				if (node == (struct __poller_node *)1)
					has_pipe_event = 1;
				continue;
			}

			switch (node->data.operation)
			{
			case PD_OP_READ:
				__poller_handle_read(node, poller);
				break;
			case PD_OP_WRITE:
				__poller_handle_write(node, poller);
				break;
			case PD_OP_LISTEN:
				__poller_handle_listen(node, poller);
				break;
			case
PD_OP_CONNECT:
				__poller_handle_connect(node, poller);
				break;
			case PD_OP_RECVFROM:
				__poller_handle_recvfrom(node, poller);
				break;
			case PD_OP_SSL_ACCEPT:
				__poller_handle_ssl_accept(node, poller);
				break;
			case PD_OP_SSL_CONNECT:
				__poller_handle_ssl_connect(node, poller);
				break;
			case PD_OP_SSL_SHUTDOWN:
				__poller_handle_ssl_shutdown(node, poller);
				break;
			case PD_OP_EVENT:
				__poller_handle_event(node, poller);
				break;
			case PD_OP_NOTIFY:
				__poller_handle_notify(node, poller);
				break;
			}
		}

		/* The pipe is handled after all fd events of this round. */
		if (has_pipe_event)
		{
			if (__poller_handle_pipe(poller))
				break;
		}

		__poller_handle_timeout(&time_node, poller);
	}

	return NULL;
}

/*
 * Create the control pipe and register its read end with event data
 * (void *)1. On success pipe_rd / pipe_wr are stored and 0 is returned; on
 * failure both ends are closed (if created) and -1 is returned.
 */
static int __poller_open_pipe(poller_t *poller)
{
	int pipefd[2];

	if (pipe(pipefd) >= 0)
	{
		if (__poller_add_fd(pipefd[0], EPOLLIN, (void *)1, poller) >= 0)
		{
			poller->pipe_rd = pipefd[0];
			poller->pipe_wr = pipefd[1];
			return 0;
		}

		close(pipefd[0]);
		close(pipefd[1]);
	}

	return -1;
}

/*
 * Create the timer fd and register it with the poller. Returns 0 and stores
 * it in poller->timerfd on success; otherwise closes it and returns -1.
 */
static int __poller_create_timer(poller_t *poller)
{
	int timerfd = __poller_create_timerfd();

	if (timerfd >= 0)
	{
		if (__poller_add_timerfd(timerfd, poller) >= 0)
		{
			poller->timerfd = timerfd;
			return 0;
		}

		__poller_close_timerfd(timerfd);
	}

	return -1;
}

/*
 * Allocate and initialise a poller that uses the caller-provided nodes_buf as
 * its fd-indexed node table. The poller is created in the stopped state.
 * Returns NULL on failure, releasing whatever had been set up; when
 * pthread_mutex_init() fails its error code is stored in errno.
 */
poller_t *__poller_create(void **nodes_buf, const struct poller_params *params)
{
	poller_t *poller = (poller_t *)malloc(sizeof (poller_t));
	int ret;

	if (!poller)
		return NULL;

	poller->pfd = __poller_create_pfd();
	if (poller->pfd >= 0)
	{
		if (__poller_create_timer(poller) >= 0)
		{
			ret = pthread_mutex_init(&poller->mutex, NULL);
			if (ret == 0)
			{
				poller->nodes = (struct __poller_node **)nodes_buf;
				poller->max_open_files = params->max_open_files;
				poller->callback = params->callback;
				poller->context = params->context;

				poller->timeo_tree.rb_node = NULL;
				poller->tree_first = NULL;
				poller->tree_last = NULL;
				INIT_LIST_HEAD(&poller->timeo_list);
				INIT_LIST_HEAD(&poller->no_timeo_list);

				poller->stopped = 1;
				return poller;
			}

			/* pthread functions return the error code instead of setting errno. */
			errno = ret;
			__poller_close_timerfd(poller->timerfd);
		}

		__poller_close_pfd(poller->pfd);
	}

	free(poller);
	return NULL;
}

poller_t
*poller_create(const struct poller_params *params)
{
	/* One zero-initialised slot per possible fd (max_open_files entries). */
	void **nodes_buf = (void **)calloc(params->max_open_files, sizeof (void *));
	poller_t *poller;

	if (nodes_buf)
	{
		poller = __poller_create(nodes_buf, params);
		if (poller)
			return poller;

		free(nodes_buf);
	}

	return NULL;
}

/* Release the poller's mutex, timer fd, poller fd and the poller itself.
 * The node table is not freed here. */
void __poller_destroy(poller_t *poller)
{
	pthread_mutex_destroy(&poller->mutex);
	__poller_close_timerfd(poller->timerfd);
	__poller_close_pfd(poller->pfd);
	free(poller);
}

/* Counterpart of poller_create(): frees the node table, then the poller. */
void poller_destroy(poller_t *poller)
{
	free(poller->nodes);
	__poller_destroy(poller);
}

/*
 * Open the control pipe and start the poller thread.
 * Returns 0 when the thread is running, -1 otherwise (the value is
 * -poller->stopped). If pthread_create() fails its error code is stored in
 * errno and both pipe ends are closed.
 */
int poller_start(poller_t *poller)
{
	pthread_t tid;
	int ret;

	pthread_mutex_lock(&poller->mutex);
	if (__poller_open_pipe(poller) >= 0)
	{
		ret = pthread_create(&tid, NULL, __poller_thread_routine, poller);
		if (ret == 0)
		{
			poller->tid = tid;
			poller->stopped = 0;
		}
		else
		{
			errno = ret;
			close(poller->pipe_wr);
			close(poller->pipe_rd);
		}
	}

	pthread_mutex_unlock(&poller->mutex);
	return -poller->stopped;
}

/*
 * Link a node that has a timeout into the timeout structures. A node whose
 * timeout is not earlier than the tail of timeo_list is appended to the list;
 * an earlier one goes into the rbtree. The timer fd is re-armed only when the
 * new node becomes the earliest pending timeout.
 */
static void __poller_insert_node(struct __poller_node *node,
								 poller_t *poller)
{
	struct __poller_node *end;

	/* When the list is empty this points at no real node, but it is then
	 * overwritten below before being used. */
	end = list_entry(poller->timeo_list.prev, struct __poller_node, list);
	if (list_empty(&poller->timeo_list))
	{
		list_add(&node->list, &poller->timeo_list);
		end = rb_entry(poller->tree_first, struct __poller_node, rb);
	}
	else if (__timeout_cmp(node, end) >= 0)
	{
		list_add_tail(&node->list, &poller->timeo_list);
		return;
	}
	else
	{
		__poller_tree_insert(node, poller);
		if (&node->rb != poller->tree_first)
			return;

		end = list_entry(poller->timeo_list.next, struct __poller_node, list);
	}

	/* `end` is the earliest node of the other container (if there is one). */
	if (!poller->tree_first || __timeout_cmp(node, end) < 0)
		__poller_set_timerfd(poller->timerfd, &node->timeout, poller);
}

/* Set node->timeout to the current CLOCK_MONOTONIC time plus `timeout`
 * milliseconds, keeping tv_nsec below one second. */
static void __poller_node_set_timeout(int timeout, struct __poller_node *node)
{
	clock_gettime(CLOCK_MONOTONIC, &node->timeout);
	node->timeout.tv_sec += timeout / 1000;
	node->timeout.tv_nsec += timeout % 1000 * 1000000;
	if (node->timeout.tv_nsec >= 1000000000)
	{
		node->timeout.tv_nsec -= 1000000000;
		node->timeout.tv_sec++;
	}
}

static int
__poller_data_get_event(int *event, const struct poller_data *data)
{
	/* Stores the event mask for data->operation in *event. The return value
	 * is used by __poller_new_node() as "a result node must be
	 * pre-allocated": 1 yes, 0 no, -1 (errno = EINVAL) unknown operation. */
	switch (data->operation)
	{
	case PD_OP_READ:
		*event = EPOLLIN | EPOLLET;
		/* A result node is needed up front only if a message is supplied. */
		return !!data->message;
	case PD_OP_WRITE:
		*event = EPOLLOUT | EPOLLET;
		return 0;
	case PD_OP_LISTEN:
		/* The only operation registered without EPOLLET. */
		*event = EPOLLIN;
		return 1;
	case PD_OP_CONNECT:
		*event = EPOLLOUT | EPOLLET;
		return 0;
	case PD_OP_RECVFROM:
		*event = EPOLLIN | EPOLLET;
		return 1;
	case PD_OP_SSL_ACCEPT:
		*event = EPOLLIN | EPOLLET;
		return 0;
	case PD_OP_SSL_CONNECT:
		*event = EPOLLOUT | EPOLLET;
		return 0;
	case PD_OP_SSL_SHUTDOWN:
		*event = EPOLLOUT | EPOLLET;
		return 0;
	case PD_OP_EVENT:
		*event = EPOLLIN | EPOLLET;
		return 1;
	case PD_OP_NOTIFY:
		*event = EPOLLIN | EPOLLET;
		return 1;
	default:
		errno = EINVAL;
		return -1;
	}
}

/*
 * Allocate a node for `data`, plus a result node when the operation needs
 * one. When timeout >= 0 the absolute timeout is computed as well.
 * Returns NULL on failure; errno is EBADF for a negative fd, EMFILE for an fd
 * that is >= max_open_files, or whatever the failing helper set.
 */
static struct __poller_node *__poller_new_node(const struct poller_data *data,
											   int timeout, poller_t *poller)
{
	struct __poller_node *res = NULL;
	struct __poller_node *node;
	int need_res;
	int event;

	/* The unsigned cast makes negative fds fail this range check too. */
	if ((size_t)data->fd >= poller->max_open_files)
	{
		errno = data->fd < 0 ? EBADF : EMFILE;
		return NULL;
	}

	need_res = __poller_data_get_event(&event, data);
	if (need_res < 0)
		return NULL;

	if (need_res)
	{
		res = (struct __poller_node *)malloc(sizeof (struct __poller_node));
		if (!res)
			return NULL;
	}

	node = (struct __poller_node *)malloc(sizeof (struct __poller_node));
	if (!node)
	{
		free(res);
		return NULL;
	}

	node->data = *data;
	node->event = event;
	node->in_rbtree = 0;
	node->removed = 0;
	node->res = res;
	if (timeout >= 0)
		__poller_node_set_timeout(timeout, node);

	return node;
}

/*
 * Register data->fd with the poller. timeout is in milliseconds; a negative
 * value means no timeout. Returns 0 on success and -1 on failure (errno is
 * EEXIST if the fd is already registered).
 */
int poller_add(const struct poller_data *data, int timeout, poller_t *poller)
{
	struct __poller_node *node;

	node = __poller_new_node(data, timeout, poller);
	if (!node)
		return -1;

	pthread_mutex_lock(&poller->mutex);
	if (!poller->nodes[data->fd])
	{
		if (__poller_add_fd(data->fd, node->event, node, poller) >= 0)
		{
			if (timeout >= 0)
				__poller_insert_node(node, poller);
			else
				list_add_tail(&node->list, &poller->no_timeo_list);

			poller->nodes[data->fd] = node;
			/* NULL marks success: the node is now owned by the poller. */
			node = NULL;
		}
	}
	else
		errno = EEXIST;
pthread_mutex_unlock(&poller->mutex);
	if (node == NULL)
		return 0;

	/* Registration failed: release the node that was never handed over. */
	free(node->res);
	free(node);
	return -1;
}

/*
 * Remove fd from the poller. The node is reported with PR_ST_DELETED: through
 * the control pipe (handled by the poller thread) while the poller is
 * running, or directly from this call when it is stopped.
 * Returns 0 on success, -1 with errno EBADF / EMFILE for an out-of-range fd
 * or ENOENT when the fd is not registered.
 */
int poller_del(int fd, poller_t *poller)
{
	struct __poller_node *node;
	int stopped = 0;

	if ((size_t)fd >= poller->max_open_files)
	{
		errno = fd < 0 ? EBADF : EMFILE;
		return -1;
	}

	pthread_mutex_lock(&poller->mutex);
	node = poller->nodes[fd];
	if (node)
	{
		poller->nodes[fd] = NULL;

		if (node->in_rbtree)
			__poller_tree_erase(node, poller);
		else
			list_del(&node->list);

		__poller_del_fd(fd, node->event, poller);

		node->error = 0;
		node->state = PR_ST_DELETED;
		stopped = poller->stopped;
		if (!stopped)
		{
			/* Hand the node to the poller thread, which runs the callback. */
			node->removed = 1;
			write(poller->pipe_wr, &node, sizeof (void *));
		}
	}
	else
		errno = ENOENT;

	pthread_mutex_unlock(&poller->mutex);
	if (stopped)
	{
		free(node->res);
		poller->callback((struct poller_result *)node, poller->context);
	}

	/* Only the pointer value is tested here; the node is not dereferenced. */
	return -!node;
}

/*
 * Replace the node registered for data->fd with a new one built from `data`
 * and `timeout`. The old node is reported with PR_ST_MODIFIED, through the
 * control pipe while running or directly when stopped.
 * Returns 0 on success, -1 on failure (errno ENOENT if the fd is not
 * registered).
 */
int poller_mod(const struct poller_data *data, int timeout, poller_t *poller)
{
	struct __poller_node *node;
	struct __poller_node *orig;
	int stopped = 0;

	node = __poller_new_node(data, timeout, poller);
	if (!node)
		return -1;

	pthread_mutex_lock(&poller->mutex);
	orig = poller->nodes[data->fd];
	if (orig)
	{
		if (__poller_mod_fd(data->fd, orig->event, node->event, node, poller) >= 0)
		{
			if (orig->in_rbtree)
				__poller_tree_erase(orig, poller);
			else
				list_del(&orig->list);

			orig->error = 0;
			orig->state = PR_ST_MODIFIED;
			stopped = poller->stopped;
			if (!stopped)
			{
				orig->removed = 1;
				write(poller->pipe_wr, &orig, sizeof (void *));
			}

			if (timeout >= 0)
				__poller_insert_node(node, poller);
			else
				list_add_tail(&node->list, &poller->no_timeo_list);

			poller->nodes[data->fd] = node;
			/* NULL marks success: the new node is owned by the poller. */
			node = NULL;
		}
	}
	else
		errno = ENOENT;

	pthread_mutex_unlock(&poller->mutex);
	/* stopped is only set after the swap succeeded, so orig is valid here. */
	if (stopped)
	{
		free(orig->res);
		poller->callback((struct poller_result *)orig, poller->context);
	}

	if (node == NULL)
		return 0;

	free(node->res);
	free(node);
	return -1;
}

/*
 * Change the timeout of the node registered for fd. timeout is in
 * milliseconds; a negative value removes the timeout.
 */
int poller_set_timeout(int fd, int timeout, poller_t *poller)
{
	struct __poller_node time_node;
	struct __poller_node *node;

	if ((size_t)fd
>= poller->max_open_files)
	{
		errno = fd < 0 ? EBADF : EMFILE;
		return -1;
	}

	/* The new absolute timeout is computed before the mutex is taken. */
	if (timeout >= 0)
		__poller_node_set_timeout(timeout, &time_node);

	pthread_mutex_lock(&poller->mutex);
	node = poller->nodes[fd];
	if (node)
	{
		/* Unlink from the current container, then re-link by new timeout. */
		if (node->in_rbtree)
			__poller_tree_erase(node, poller);
		else
			list_del(&node->list);

		if (timeout >= 0)
		{
			node->timeout = time_node.timeout;
			__poller_insert_node(node, poller);
		}
		else
			list_add_tail(&node->list, &poller->no_timeo_list);
	}
	else
		errno = ENOENT;

	pthread_mutex_unlock(&poller->mutex);
	/* 0 if the fd was registered, -1 (errno ENOENT) otherwise. */
	return -!node;
}

/*
 * Add a timer node (operation PD_OP_TIMER, fd -1) that carries `context`.
 * `value` is relative to now; a negative tv_sec puts the node on the
 * no-timeout list instead. The node pointer is stored in *timer as the handle
 * for poller_del_timer().
 * Returns 0 on success, -1 on failure (errno EINVAL when tv_nsec is outside
 * [0, 1000000000)).
 */
int poller_add_timer(const struct timespec *value, void *context, void **timer,
					 poller_t *poller)
{
	struct __poller_node *node;

	if (value->tv_nsec < 0 || value->tv_nsec >= 1000000000)
	{
		errno = EINVAL;
		return -1;
	}

	node = (struct __poller_node *)malloc(sizeof (struct __poller_node));
	if (node)
	{
		memset(&node->data, 0, sizeof (struct poller_data));
		node->data.operation = PD_OP_TIMER;
		node->data.fd = -1;
		node->data.context = context;
		node->in_rbtree = 0;
		node->removed = 0;
		node->res = NULL;

		if (value->tv_sec >= 0)
		{
			clock_gettime(CLOCK_MONOTONIC, &node->timeout);
			node->timeout.tv_sec += value->tv_sec;
			node->timeout.tv_nsec += value->tv_nsec;
			if (node->timeout.tv_nsec >= 1000000000)
			{
				node->timeout.tv_nsec -= 1000000000;
				node->timeout.tv_sec++;
			}
		}

		/* The handle is published before the node is linked in. */
		*timer = node;
		pthread_mutex_lock(&poller->mutex);
		if (value->tv_sec >= 0)
			__poller_insert_node(node, poller);
		else
			list_add_tail(&node->list, &poller->no_timeo_list);
		pthread_mutex_unlock(&poller->mutex);
		return 0;
	}

	return -1;
}

/*
 * Cancel a timer returned by poller_add_timer(). The node is reported with
 * PR_ST_DELETED, through the control pipe while the poller is running or
 * directly when it is stopped. Fails with ENOENT if the node is already
 * marked removed.
 */
int poller_del_timer(void *timer, poller_t *poller)
{
	struct __poller_node *node = (struct __poller_node *)timer;
	int stopped = 0;

	pthread_mutex_lock(&poller->mutex);
	if (!node->removed)
	{
		node->removed = 1;

		if (node->in_rbtree)
			__poller_tree_erase(node, poller);
		else
			list_del(&node->list);

		node->error = 0;
		node->state = PR_ST_DELETED;
		stopped = poller->stopped;
		if (!stopped)
			write(poller->pipe_wr, &node, sizeof (void *));
	}
	else
	{
		errno = ENOENT;
		node = NULL;
	}
pthread_mutex_unlock(&poller->mutex);
	/* stopped is only non-zero when the node was valid and unlinked above. */
	if (stopped)
		poller->callback((struct poller_result *)node, poller->context);

	return -!node;
}

/*
 * Stop the poller thread and flush every remaining node. A NULL pointer
 * written to the control pipe is the stop request; after the thread has been
 * joined, the pipe is drained once more and all nodes still linked in the
 * rbtree or in either list are reported with PR_ST_STOPPED.
 */
void poller_stop(poller_t *poller)
{
	struct __poller_node *node;
	struct list_head *pos, *tmp;
	LIST_HEAD(node_list);
	void *p = NULL;

	write(poller->pipe_wr, &p, sizeof (void *));
	pthread_join(poller->tid, NULL);
	poller->stopped = 1;

	pthread_mutex_lock(&poller->mutex);
	/* With the write end closed, the drain below cannot block on an empty
	 * pipe; it delivers nodes queued by poller_del()/poller_mod(). */
	close(poller->pipe_wr);
	__poller_handle_pipe(poller);
	close(poller->pipe_rd);

	poller->tree_first = NULL;
	poller->tree_last = NULL;
	/* Empty the rbtree by erasing its root until nothing is left. */
	while (poller->timeo_tree.rb_node)
	{
		node = rb_entry(poller->timeo_tree.rb_node, struct __poller_node, rb);
		rb_erase(&node->rb, &poller->timeo_tree);
		list_add(&node->list, &node_list);
	}

	list_splice_init(&poller->timeo_list, &node_list);
	list_splice_init(&poller->no_timeo_list, &node_list);
	list_for_each(pos, &node_list)
	{
		node = list_entry(pos, struct __poller_node, list);
		if (node->data.fd >= 0)
		{
			poller->nodes[node->data.fd] = NULL;
			__poller_del_fd(node->data.fd, node->event, poller);
		}
		else
			node->removed = 1;
	}

	pthread_mutex_unlock(&poller->mutex);
	/* Callbacks run without the mutex held. */
	list_for_each_safe(pos, tmp, &node_list)
	{
		node = list_entry(pos, struct __poller_node, list);
		node->error = 0;
		node->state = PR_ST_STOPPED;
		free(node->res);
		poller->callback((struct poller_result *)node, poller->context);
	}
}

workflow-0.11.8/src/kernel/poller.h000066400000000000000000000055211476003635400172030ustar00rootroot00000000000000/*
  Copyright (c) 2019 Sogou, Inc.

  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

  Unless required by applicable law or agreed to in writing, software
  distributed under the License is distributed on an "AS IS" BASIS,
  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  See the License for the specific language governing permissions and
  limitations under the License.
  Author: Xie Han (xiehan@sogou-inc.com)
*/

#ifndef _POLLER_H_
#define _POLLER_H_

/* NOTE(review): the header names of the four #include lines below are missing
 * in this copy of the file - restore them from the upstream source. */
#include
#include
#include
#include

typedef struct __poller poller_t;
typedef struct __poller_message poller_message_t;

/* Message being assembled by a read operation. append() receives a buffer,
 * an in/out byte count and the message itself. */
struct __poller_message
{
	int (*append)(const void *, size_t *, poller_message_t *);
	char data[0];
};

/* Description of one operation registered with the poller. */
struct poller_data
{
/* Operation codes; the SSL read/write codes alias the plain ones. */
#define PD_OP_TIMER			0
#define PD_OP_READ			1
#define PD_OP_WRITE			2
#define PD_OP_LISTEN		3
#define PD_OP_CONNECT		4
#define PD_OP_RECVFROM		5
#define PD_OP_SSL_READ		PD_OP_READ
#define PD_OP_SSL_WRITE		PD_OP_WRITE
#define PD_OP_SSL_ACCEPT	6
#define PD_OP_SSL_CONNECT	7
#define PD_OP_SSL_SHUTDOWN	8
#define PD_OP_EVENT			9
#define PD_OP_NOTIFY		10
	short operation;
	unsigned short iovcnt;
	int fd;
	SSL *ssl;
	/* Per-operation callback; one member per kind of operation. */
	union
	{
		poller_message_t *(*create_message)(void *);
		int (*partial_written)(size_t, void *);
		void *(*accept)(const struct sockaddr *, socklen_t, int, void *);
		void *(*recvfrom)(const struct sockaddr *, socklen_t,
						  const void *, size_t, void *);
		void *(*event)(void *);
		void *(*notify)(void *, void *);
	};
	void *context;
	/* Per-operation payload. */
	union
	{
		poller_message_t *message;
		struct iovec *write_iov;
		void *result;
	};
};

/* Result passed to the poller callback. */
struct poller_result
{
#define PR_ST_SUCCESS		0
#define PR_ST_FINISHED		1
#define PR_ST_ERROR			2
#define PR_ST_DELETED		3
#define PR_ST_MODIFIED		4
#define PR_ST_STOPPED		5
	int state;
	int error;
	struct poller_data data;
	/* In callback, spaces of six pointers are available from here.
*/
};

/* Parameters for poller_create(). */
struct poller_params
{
	/* Size of the fd-indexed node table; fds must be below this value. */
	size_t max_open_files;
	/* Called with every result and with `context` as second argument. */
	void (*callback)(struct poller_result *, void *);
	void *context;
};

#ifdef __cplusplus
extern "C"
{
#endif

/* Public poller API. Timeouts of poller_add / poller_mod /
 * poller_set_timeout are in milliseconds; a negative value means none. */
poller_t *poller_create(const struct poller_params *params);
int poller_start(poller_t *poller);
int poller_add(const struct poller_data *data, int timeout, poller_t *poller);
int poller_del(int fd, poller_t *poller);
int poller_mod(const struct poller_data *data, int timeout, poller_t *poller);
int poller_set_timeout(int fd, int timeout, poller_t *poller);
int poller_add_timer(const struct timespec *value, void *context, void **timer,
					 poller_t *poller);
int poller_del_timer(void *timer, poller_t *poller);
void poller_stop(poller_t *poller);
void poller_destroy(poller_t *poller);

#ifdef __cplusplus
}
#endif

#endif

workflow-0.11.8/src/kernel/rbtree.c000066400000000000000000000213431476003635400171640ustar00rootroot00000000000000/*
  Red Black Trees
  (C) 1999 Andrea Arcangeli
  (C) 2002 David Woodhouse

  This program is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 2 of the License, or
  (at your option) any later version.

  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  GNU General Public License for more details.
  You should have received a copy of the GNU General Public License
  along with this program; if not, write to the Free Software
  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA

  linux/lib/rbtree.c
*/

#include "rbtree.h"

/* Rotate left around `node`: its right child takes its place and `node`
 * becomes that child's left child. root->rb_node is updated when `node` was
 * the root. */
static void __rb_rotate_left(struct rb_node *node, struct rb_root *root)
{
	struct rb_node *right = node->rb_right;

	if ((node->rb_right = right->rb_left))
		right->rb_left->rb_parent = node;
	right->rb_left = node;

	if ((right->rb_parent = node->rb_parent))
	{
		if (node == node->rb_parent->rb_left)
			node->rb_parent->rb_left = right;
		else
			node->rb_parent->rb_right = right;
	}
	else
		root->rb_node = right;
	node->rb_parent = right;
}

/* Mirror image of __rb_rotate_left(): the left child takes `node`'s place. */
static void __rb_rotate_right(struct rb_node *node, struct rb_root *root)
{
	struct rb_node *left = node->rb_left;

	if ((node->rb_left = left->rb_right))
		left->rb_right->rb_parent = node;
	left->rb_right = node;

	if ((left->rb_parent = node->rb_parent))
	{
		if (node == node->rb_parent->rb_right)
			node->rb_parent->rb_right = left;
		else
			node->rb_parent->rb_left = left;
	}
	else
		root->rb_node = left;
	node->rb_parent = left;
}

/* Rebalance the tree after `node` has been linked in: walks up while the
 * parent is red, recolouring and rotating, and finally colours the root
 * black. */
void rb_insert_color(struct rb_node *node, struct rb_root *root)
{
	struct rb_node *parent, *gparent;

	while ((parent = node->rb_parent) && parent->rb_color == RB_RED)
	{
		gparent = parent->rb_parent;

		if (parent == gparent->rb_left)
		{
			{
				/* Red uncle: recolour and continue from the grandparent. */
				register struct rb_node *uncle = gparent->rb_right;
				if (uncle && uncle->rb_color == RB_RED)
				{
					uncle->rb_color = RB_BLACK;
					parent->rb_color = RB_BLACK;
					gparent->rb_color = RB_RED;
					node = gparent;
					continue;
				}
			}

			/* Inner child: rotate it to the outside first. */
			if (parent->rb_right == node)
			{
				register struct rb_node *tmp;
				__rb_rotate_left(parent, root);
				tmp = parent;
				parent = node;
				node = tmp;
			}

			parent->rb_color = RB_BLACK;
			gparent->rb_color = RB_RED;
			__rb_rotate_right(gparent, root);
		}
		else
		{
			{
				/* Mirror case: the parent is a right child. */
				register struct rb_node *uncle = gparent->rb_left;
				if (uncle && uncle->rb_color == RB_RED)
				{
					uncle->rb_color = RB_BLACK;
					parent->rb_color = RB_BLACK;
					gparent->rb_color = RB_RED;
					node = gparent;
					continue;
				}
			}

			if (parent->rb_left == node)
{
				register struct rb_node *tmp;
				__rb_rotate_right(parent, root);
				tmp = parent;
				parent = node;
				node = tmp;
			}

			parent->rb_color = RB_BLACK;
			gparent->rb_color = RB_RED;
			__rb_rotate_left(gparent, root);
		}
	}

	/* The root is always coloured black. */
	root->rb_node->rb_color = RB_BLACK;
}

/* Rebalance after a black node has been unlinked. `node` is the child that
 * took its place (may be NULL) and `parent` is that child's parent. */
static void __rb_erase_color(struct rb_node *node, struct rb_node *parent,
							 struct rb_root *root)
{
	struct rb_node *other;

	while ((!node || node->rb_color == RB_BLACK) && node != root->rb_node)
	{
		if (parent->rb_left == node)
		{
			/* `other` is the sibling of `node`. */
			other = parent->rb_right;
			if (other->rb_color == RB_RED)
			{
				other->rb_color = RB_BLACK;
				parent->rb_color = RB_RED;
				__rb_rotate_left(parent, root);
				other = parent->rb_right;
			}
			/* Both children of the sibling are black (or absent): recolour
			 * the sibling and continue one level up. */
			if ((!other->rb_left ||
				 other->rb_left->rb_color == RB_BLACK)
				&& (!other->rb_right ||
					other->rb_right->rb_color == RB_BLACK))
			{
				other->rb_color = RB_RED;
				node = parent;
				parent = node->rb_parent;
			}
			else
			{
				if (!other->rb_right ||
					other->rb_right->rb_color == RB_BLACK)
				{
					register struct rb_node *o_left;
					if ((o_left = other->rb_left))
						o_left->rb_color = RB_BLACK;
					other->rb_color = RB_RED;
					__rb_rotate_right(other, root);
					other = parent->rb_right;
				}
				other->rb_color = parent->rb_color;
				parent->rb_color = RB_BLACK;
				if (other->rb_right)
					other->rb_right->rb_color = RB_BLACK;
				__rb_rotate_left(parent, root);
				/* Done: setting node to the root ends the loop. */
				node = root->rb_node;
				break;
			}
		}
		else
		{
			/* Mirror case: `node` is the right child of `parent`. */
			other = parent->rb_left;
			if (other->rb_color == RB_RED)
			{
				other->rb_color = RB_BLACK;
				parent->rb_color = RB_RED;
				__rb_rotate_right(parent, root);
				other = parent->rb_left;
			}
			if ((!other->rb_left ||
				 other->rb_left->rb_color == RB_BLACK)
				&& (!other->rb_right ||
					other->rb_right->rb_color == RB_BLACK))
			{
				other->rb_color = RB_RED;
				node = parent;
				parent = node->rb_parent;
			}
			else
			{
				if (!other->rb_left ||
					other->rb_left->rb_color == RB_BLACK)
				{
					register struct rb_node *o_right;
					if ((o_right = other->rb_right))
						o_right->rb_color = RB_BLACK;
					other->rb_color = RB_RED;
					__rb_rotate_left(other, root);
					other = parent->rb_left;
				}
				other->rb_color = parent->rb_color;
				parent->rb_color = RB_BLACK;
				if (other->rb_left)
other->rb_left->rb_color = RB_BLACK;
				__rb_rotate_right(parent, root);
				node = root->rb_node;
				break;
			}
		}
	}
	if (node)
		node->rb_color = RB_BLACK;
}

/* Unlink `node` from the tree and rebalance if a black node was removed. */
void rb_erase(struct rb_node *node, struct rb_root *root)
{
	struct rb_node *child, *parent;
	int color;

	if (!node->rb_left)
		child = node->rb_right;
	else if (!node->rb_right)
		child = node->rb_left;
	else
	{
		/* Two children: the leftmost node of the right subtree is unlinked
		 * from its position and then put in place of the erased node,
		 * taking over its links and colour. */
		struct rb_node *old = node, *left;

		node = node->rb_right;
		while ((left = node->rb_left))
			node = left;
		child = node->rb_right;
		parent = node->rb_parent;
		color = node->rb_color;

		if (child)
			child->rb_parent = parent;
		if (parent)
		{
			if (parent->rb_left == node)
				parent->rb_left = child;
			else
				parent->rb_right = child;
		}
		else
			root->rb_node = child;

		/* The replacement was the direct right child of the erased node. */
		if (node->rb_parent == old)
			parent = node;
		node->rb_parent = old->rb_parent;
		node->rb_color = old->rb_color;
		node->rb_right = old->rb_right;
		node->rb_left = old->rb_left;

		if (old->rb_parent)
		{
			if (old->rb_parent->rb_left == old)
				old->rb_parent->rb_left = node;
			else
				old->rb_parent->rb_right = node;
		}
		else
			root->rb_node = node;

		old->rb_left->rb_parent = node;
		if (old->rb_right)
			old->rb_right->rb_parent = node;
		goto color;
	}

	/* At most one child: splice the child into the node's place. */
	parent = node->rb_parent;
	color = node->rb_color;

	if (child)
		child->rb_parent = parent;
	if (parent)
	{
		if (parent->rb_left == node)
			parent->rb_left = child;
		else
			parent->rb_right = child;
	}
	else
		root->rb_node = child;

color:
	if (color == RB_BLACK)
		__rb_erase_color(child, parent, root);
}

/*
 * This function returns the first node (in sort order) of the tree.
 */
struct rb_node *rb_first(struct rb_root *root)
{
	struct rb_node *n;

	n = root->rb_node;
	if (!n)
		return (struct rb_node *)0;

	while (n->rb_left)
		n = n->rb_left;

	return n;
}

/* Returns the last node (in sort order) of the tree, or NULL if it is empty. */
struct rb_node *rb_last(struct rb_root *root)
{
	struct rb_node *n;

	n = root->rb_node;
	if (!n)
		return (struct rb_node *)0;

	while (n->rb_right)
		n = n->rb_right;

	return n;
}

/* Returns the in-order successor of `node`, or NULL if there is none. */
struct rb_node *rb_next(struct rb_node *node)
{
	/* If we have a right-hand child, go down and then left as far
	   as we can.
*/
	if (node->rb_right)
	{
		node = node->rb_right;
		while (node->rb_left)
			node = node->rb_left;
		return node;
	}

	/* No right-hand children. Everything down and left is
	   smaller than us, so any 'next' node must be in the general
	   direction of our parent. Go up the tree; any time the
	   ancestor is a right-hand child of its parent, keep going
	   up. First time it's a left-hand child of its parent, said
	   parent is our 'next' node. */
	while (node->rb_parent && node == node->rb_parent->rb_right)
		node = node->rb_parent;

	return node->rb_parent;
}

/* Returns the in-order predecessor of `node`, or NULL if there is none. */
struct rb_node *rb_prev(struct rb_node *node)
{
	/* If we have a left-hand child, go down and then right as far
	   as we can. */
	if (node->rb_left)
	{
		node = node->rb_left;
		while (node->rb_right)
			node = node->rb_right;
		return node;
	}

	/* No left-hand children. Go up till we find an ancestor which
	   is a right-hand child of its parent */
	while (node->rb_parent && node == node->rb_parent->rb_left)
		node = node->rb_parent;

	return node->rb_parent;
}

/* Put `newnode` in the tree position of `victim` without rebalancing:
 * the parent (or the root) and both children are re-pointed, then victim's
 * links and colour are copied into newnode. */
void rb_replace_node(struct rb_node *victim, struct rb_node *newnode,
					 struct rb_root *root)
{
	struct rb_node *parent = victim->rb_parent;

	/* Set the surrounding nodes to point to the replacement */
	if (parent)
	{
		if (victim == parent->rb_left)
			parent->rb_left = newnode;
		else
			parent->rb_right = newnode;
	}
	else
	{
		root->rb_node = newnode;
	}

	if (victim->rb_left)
		victim->rb_left->rb_parent = newnode;
	if (victim->rb_right)
		victim->rb_right->rb_parent = newnode;

	/* Copy the pointers/colour from the victim to the replacement */
	*newnode = *victim;
}

workflow-0.11.8/src/kernel/rbtree.h000066400000000000000000000100311476003635400171610ustar00rootroot00000000000000/*
  Red Black Trees
  (C) 1999 Andrea Arcangeli

  This program is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 2 of the License, or
  (at your option) any later version.
This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program; if not, write to the Free Software Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA linux/include/linux/rbtree.h To use rbtrees you'll have to implement your own insert and search cores. This will avoid us to use callbacks and to drop drammatically performances. I know it's not the cleaner way, but in C (not in C++) to get performances and genericity... Some example of insert and search follows here. The search is a plain normal search over an ordered tree. The insert instead must be implemented int two steps: as first thing the code must insert the element in order as a red leaf in the tree, then the support library function rb_insert_color() must be called. Such function will do the not trivial work to rebalance the rbtree if necessary. 
----------------------------------------------------------------------- static inline struct page * rb_search_page_cache(struct inode * inode, unsigned long offset) { rb_node_t * n = inode->i_rb_page_cache.rb_node; struct page * page; while (n) { page = rb_entry(n, struct page, rb_page_cache); if (offset < page->offset) n = n->rb_left; else if (offset > page->offset) n = n->rb_right; else return page; } return NULL; } static inline struct page * __rb_insert_page_cache(struct inode * inode, unsigned long offset, rb_node_t * node) { rb_node_t ** p = &inode->i_rb_page_cache.rb_node; rb_node_t * parent = NULL; struct page * page; while (*p) { parent = *p; page = rb_entry(parent, struct page, rb_page_cache); if (offset < page->offset) p = &(*p)->rb_left; else if (offset > page->offset) p = &(*p)->rb_right; else return page; } rb_link_node(node, parent, p); return NULL; } static inline struct page * rb_insert_page_cache(struct inode * inode, unsigned long offset, rb_node_t * node) { struct page * ret; if ((ret = __rb_insert_page_cache(inode, offset, node))) goto out; rb_insert_color(node, &inode->i_rb_page_cache); out: return ret; } ----------------------------------------------------------------------- */ #ifndef _LINUX_RBTREE_H #define _LINUX_RBTREE_H #pragma pack(1) struct rb_node { struct rb_node *rb_parent; struct rb_node *rb_right; struct rb_node *rb_left; char rb_color; #define RB_RED 0 #define RB_BLACK 1 }; #pragma pack() struct rb_root { struct rb_node *rb_node; }; #define RB_ROOT (struct rb_root){ (struct rb_node *)0, } #define rb_entry(ptr, type, member) \ ((type *)((char *)(ptr)-(unsigned long)(&((type *)0)->member))) #ifdef __cplusplus extern "C" { #endif extern void rb_insert_color(struct rb_node *node, struct rb_root *root); extern void rb_erase(struct rb_node *node, struct rb_root *root); /* Find logical next and previous nodes in a tree */ extern struct rb_node *rb_next(struct rb_node *); extern struct rb_node *rb_prev(struct rb_node *); extern struct 
rb_node *rb_first(struct rb_root *); extern struct rb_node *rb_last(struct rb_root *); /* Fast replacement of a single node without remove/rebalance/add/rebalance */ extern void rb_replace_node(struct rb_node *victim, struct rb_node *newnode, struct rb_root *root); #ifdef __cplusplus } #endif static inline void rb_link_node(struct rb_node *node, struct rb_node *parent, struct rb_node **link) { node->rb_parent = parent; node->rb_color = RB_RED; node->rb_left = node->rb_right = (struct rb_node *)0; *link = node; } #endif workflow-0.11.8/src/kernel/thrdpool.c000066400000000000000000000137561476003635400175450ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Xie Han (xiehan@sogou-inc.com) */ #include #include #include #include "msgqueue.h" #include "thrdpool.h" struct __thrdpool { msgqueue_t *msgqueue; size_t nthreads; size_t stacksize; pthread_t tid; pthread_mutex_t mutex; pthread_key_t key; pthread_cond_t *terminate; }; struct __thrdpool_task_entry { void *link; struct thrdpool_task task; }; static pthread_t __zero_tid; static void __thrdpool_exit_routine(void *context) { thrdpool_t *pool = (thrdpool_t *)context; pthread_t tid; /* One thread joins another. Don't need to keep all thread IDs. 
*/ pthread_mutex_lock(&pool->mutex); tid = pool->tid; pool->tid = pthread_self(); if (--pool->nthreads == 0 && pool->terminate) pthread_cond_signal(pool->terminate); pthread_mutex_unlock(&pool->mutex); if (!pthread_equal(tid, __zero_tid)) pthread_join(tid, NULL); pthread_exit(NULL); } static void *__thrdpool_routine(void *arg) { thrdpool_t *pool = (thrdpool_t *)arg; struct __thrdpool_task_entry *entry; void (*task_routine)(void *); void *task_context; pthread_setspecific(pool->key, pool); while (!pool->terminate) { entry = (struct __thrdpool_task_entry *)msgqueue_get(pool->msgqueue); if (!entry) break; task_routine = entry->task.routine; task_context = entry->task.context; free(entry); task_routine(task_context); if (pool->nthreads == 0) { /* Thread pool was destroyed by the task. */ free(pool); return NULL; } } __thrdpool_exit_routine(pool); return NULL; } static void __thrdpool_terminate(int in_pool, thrdpool_t *pool) { pthread_cond_t term = PTHREAD_COND_INITIALIZER; pthread_mutex_lock(&pool->mutex); msgqueue_set_nonblock(pool->msgqueue); pool->terminate = &term; if (in_pool) { /* Thread pool destroyed in a pool thread is legal. 
*/ pthread_detach(pthread_self()); pool->nthreads--; } while (pool->nthreads > 0) pthread_cond_wait(&term, &pool->mutex); pthread_mutex_unlock(&pool->mutex); if (!pthread_equal(pool->tid, __zero_tid)) pthread_join(pool->tid, NULL); } static int __thrdpool_create_threads(size_t nthreads, thrdpool_t *pool) { pthread_attr_t attr; pthread_t tid; int ret; ret = pthread_attr_init(&attr); if (ret == 0) { if (pool->stacksize) pthread_attr_setstacksize(&attr, pool->stacksize); while (pool->nthreads < nthreads) { ret = pthread_create(&tid, &attr, __thrdpool_routine, pool); if (ret == 0) pool->nthreads++; else break; } pthread_attr_destroy(&attr); if (pool->nthreads == nthreads) return 0; __thrdpool_terminate(0, pool); } errno = ret; return -1; } thrdpool_t *thrdpool_create(size_t nthreads, size_t stacksize) { thrdpool_t *pool; int ret; pool = (thrdpool_t *)malloc(sizeof (thrdpool_t)); if (!pool) return NULL; pool->msgqueue = msgqueue_create(0, 0); if (pool->msgqueue) { ret = pthread_mutex_init(&pool->mutex, NULL); if (ret == 0) { ret = pthread_key_create(&pool->key, NULL); if (ret == 0) { pool->stacksize = stacksize; pool->nthreads = 0; pool->tid = __zero_tid; pool->terminate = NULL; if (__thrdpool_create_threads(nthreads, pool) >= 0) return pool; pthread_key_delete(pool->key); } pthread_mutex_destroy(&pool->mutex); } errno = ret; msgqueue_destroy(pool->msgqueue); } free(pool); return NULL; } inline void __thrdpool_schedule(const struct thrdpool_task *task, void *buf, thrdpool_t *pool); void __thrdpool_schedule(const struct thrdpool_task *task, void *buf, thrdpool_t *pool) { ((struct __thrdpool_task_entry *)buf)->task = *task; msgqueue_put(buf, pool->msgqueue); } int thrdpool_schedule(const struct thrdpool_task *task, thrdpool_t *pool) { void *buf = malloc(sizeof (struct __thrdpool_task_entry)); if (buf) { __thrdpool_schedule(task, buf, pool); return 0; } return -1; } inline int thrdpool_in_pool(thrdpool_t *pool); int thrdpool_in_pool(thrdpool_t *pool) { return 
pthread_getspecific(pool->key) == pool; } int thrdpool_increase(thrdpool_t *pool) { pthread_attr_t attr; pthread_t tid; int ret; ret = pthread_attr_init(&attr); if (ret == 0) { if (pool->stacksize) pthread_attr_setstacksize(&attr, pool->stacksize); pthread_mutex_lock(&pool->mutex); ret = pthread_create(&tid, &attr, __thrdpool_routine, pool); if (ret == 0) pool->nthreads++; pthread_mutex_unlock(&pool->mutex); pthread_attr_destroy(&attr); if (ret == 0) return 0; } errno = ret; return -1; } int thrdpool_decrease(thrdpool_t *pool) { void *buf = malloc(sizeof (struct __thrdpool_task_entry)); struct __thrdpool_task_entry *entry; if (buf) { entry = (struct __thrdpool_task_entry *)buf; entry->task.routine = __thrdpool_exit_routine; entry->task.context = pool; msgqueue_put_head(entry, pool->msgqueue); return 0; } return -1; } void thrdpool_exit(thrdpool_t *pool) { if (thrdpool_in_pool(pool)) __thrdpool_exit_routine(pool); } void thrdpool_destroy(void (*pending)(const struct thrdpool_task *), thrdpool_t *pool) { int in_pool = thrdpool_in_pool(pool); struct __thrdpool_task_entry *entry; __thrdpool_terminate(in_pool, pool); while (1) { entry = (struct __thrdpool_task_entry *)msgqueue_get(pool->msgqueue); if (!entry) break; if (pending && entry->task.routine != __thrdpool_exit_routine) pending(&entry->task); free(entry); } pthread_key_delete(pool->key); pthread_mutex_destroy(&pool->mutex); msgqueue_destroy(pool->msgqueue); if (!in_pool) free(pool); } workflow-0.11.8/src/kernel/thrdpool.h000066400000000000000000000033541476003635400175430ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Xie Han (xiehan@sogou-inc.com) */ #ifndef _THRDPOOL_H_ #define _THRDPOOL_H_ #include typedef struct __thrdpool thrdpool_t; struct thrdpool_task { void (*routine)(void *); void *context; }; #ifdef __cplusplus extern "C" { #endif /* * Thread pool originates from project Sogou C++ Workflow * https://github.com/sogou/workflow * * A thread task can be scheduled by another task, which is very important, * even if the pool is being destroyed. Because thread task is hard to know * what's happening to the pool. * The thread pool can also be destroyed by a thread task. This may sound * strange, but it's very logical. Destroying thread pool in thread task * does not end the task thread. It'll run till the end of task. 
*/ thrdpool_t *thrdpool_create(size_t nthreads, size_t stacksize); int thrdpool_schedule(const struct thrdpool_task *task, thrdpool_t *pool); int thrdpool_in_pool(thrdpool_t *pool); int thrdpool_increase(thrdpool_t *pool); int thrdpool_decrease(thrdpool_t *pool); void thrdpool_exit(thrdpool_t *pool); void thrdpool_destroy(void (*pending)(const struct thrdpool_task *), thrdpool_t *pool); #ifdef __cplusplus } #endif #endif workflow-0.11.8/src/kernel/xmake.lua000066400000000000000000000003401476003635400173370ustar00rootroot00000000000000target("kernel") set_kind("object") add_files("*.cc") add_files("*.c") if is_plat("linux", "android") then remove_files("IOService_thread.cc") else remove_files("IOService_linux.cc") end workflow-0.11.8/src/manager/000077500000000000000000000000001476003635400156645ustar00rootroot00000000000000workflow-0.11.8/src/manager/CMakeLists.txt000066400000000000000000000003521476003635400204240ustar00rootroot00000000000000cmake_minimum_required(VERSION 3.6) project(manager) set(SRC DnsCache.cc RouteManager.cc WFGlobal.cc ) if (NOT UPSTREAM STREQUAL "n") set(SRC ${SRC} UpstreamManager.cc ) endif () add_library(${PROJECT_NAME} OBJECT ${SRC}) workflow-0.11.8/src/manager/DnsCache.cc000066400000000000000000000053161476003635400176500ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) Xie Han (xiehan@sogou-inc.com) */ #include #include #include "DnsCache.h" #define GET_CURRENT_SECOND std::chrono::duration_cast(std::chrono::steady_clock::now().time_since_epoch()).count() #define TTL_INC 5 const DnsCache::DnsHandle *DnsCache::get_inner(const HostPort& host_port, int type) { int64_t cur = GET_CURRENT_SECOND; std::lock_guard lock(mutex_); const DnsHandle *handle = cache_pool_.get(host_port); if (handle && ((type == GET_TYPE_TTL && cur > handle->value.expire_time) || (type == GET_TYPE_CONFIDENT && cur > handle->value.confident_time))) { if (!handle->value.delayed()) { DnsHandle *h = const_cast(handle); if (type == GET_TYPE_TTL) h->value.expire_time += TTL_INC; else h->value.confident_time += TTL_INC; h->value.addrinfo->ai_flags |= 2; } cache_pool_.release(handle); return NULL; } return handle; } const DnsCache::DnsHandle *DnsCache::put(const HostPort& host_port, struct addrinfo *addrinfo, unsigned int dns_ttl_default, unsigned int dns_ttl_min) { int64_t expire_time; int64_t confident_time; int64_t cur_time = GET_CURRENT_SECOND; if (dns_ttl_min > dns_ttl_default) dns_ttl_min = dns_ttl_default; if (dns_ttl_min == (unsigned int)-1) confident_time = INT64_MAX; else confident_time = cur_time + dns_ttl_min; if (dns_ttl_default == (unsigned int)-1) expire_time = INT64_MAX; else expire_time = cur_time + dns_ttl_default; std::lock_guard lock(mutex_); return cache_pool_.put(host_port, {addrinfo, confident_time, expire_time}); } const DnsCache::DnsHandle *DnsCache::get(const DnsCache::HostPort& host_port) { std::lock_guard lock(mutex_); return cache_pool_.get(host_port); } void DnsCache::release(const DnsCache::DnsHandle *handle) { std::lock_guard lock(mutex_); cache_pool_.release(handle); } void DnsCache::del(const DnsCache::HostPort& key) { std::lock_guard lock(mutex_); cache_pool_.del(key); } DnsCache::DnsCache() { } DnsCache::~DnsCache() { } 
workflow-0.11.8/src/manager/DnsCache.h000066400000000000000000000075331476003635400175150ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) */ #ifndef _DNSCACHE_H_ #define _DNSCACHE_H_ #include #include #include #include #include #include "LRUCache.h" #include "DnsUtil.h" #define GET_TYPE_TTL 0 #define GET_TYPE_CONFIDENT 1 struct DnsCacheValue { struct addrinfo *addrinfo; int64_t confident_time; int64_t expire_time; bool delayed() const { return addrinfo->ai_flags & 2; } }; // RAII: NO. 
Release handle by user // Thread safety: YES // MUST call release when handle no longer used class DnsCache { public: using HostPort = std::pair; using DnsHandle = LRUHandle; public: // get handler // Need call release when handle no longer needed //Handle *get(const KEY &key); const DnsHandle *get(const HostPort& host_port); const DnsHandle *get(const std::string& host, unsigned short port) { return get(HostPort(host, port)); } const DnsHandle *get(const char *host, unsigned short port) { return get(std::string(host), port); } const DnsHandle *get_ttl(const HostPort& host_port) { return get_inner(host_port, GET_TYPE_TTL); } const DnsHandle *get_ttl(const std::string& host, unsigned short port) { return get_ttl(HostPort(host, port)); } const DnsHandle *get_ttl(const char *host, unsigned short port) { return get_ttl(std::string(host), port); } const DnsHandle *get_confident(const HostPort& host_port) { return get_inner(host_port, GET_TYPE_CONFIDENT); } const DnsHandle *get_confident(const std::string& host, unsigned short port) { return get_confident(HostPort(host, port)); } const DnsHandle *get_confident(const char *host, unsigned short port) { return get_confident(std::string(host), port); } const DnsHandle *put(const HostPort& host_port, struct addrinfo *addrinfo, unsigned int dns_ttl_default, unsigned int dns_ttl_min); const DnsHandle *put(const std::string& host, unsigned short port, struct addrinfo *addrinfo, unsigned int dns_ttl_default, unsigned int dns_ttl_min) { return put(HostPort(host, port), addrinfo, dns_ttl_default, dns_ttl_min); } const DnsHandle *put(const char *host, unsigned short port, struct addrinfo *addrinfo, unsigned int dns_ttl_default, unsigned int dns_ttl_min) { return put(std::string(host), port, addrinfo, dns_ttl_default, dns_ttl_min); } // release handle by get/put void release(const DnsHandle *handle); // delete from cache, deleter delay called when all inuse-handle release. 
void del(const HostPort& key); void del(const std::string& host, unsigned short port) { del(HostPort(host, port)); } void del(const char *host, unsigned short port) { del(std::string(host), port); } private: const DnsHandle *get_inner(const HostPort& host_port, int type); std::mutex mutex_; class ValueDeleter { public: void operator() (const DnsCacheValue& value) const { struct addrinfo *ai = value.addrinfo; if (ai) { if (ai->ai_flags) freeaddrinfo(ai); else protocol::DnsUtil::freeaddrinfo(ai); } } }; LRUCache cache_pool_; public: // To prevent inline calling LRUCache's constructor and deconstructor. DnsCache(); ~DnsCache(); }; #endif workflow-0.11.8/src/manager/EndpointParams.h000066400000000000000000000024441476003635400207650ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) */ #ifndef _ENDPOINTPARAMS_H_ #define _ENDPOINTPARAMS_H_ #include #include /** * @file EndpointParams.h * @brief Network config for client task */ enum TransportType { TT_TCP, TT_UDP, TT_SCTP, TT_TCP_SSL, TT_SCTP_SSL, }; struct EndpointParams { int address_family; size_t max_connections; int connect_timeout; int response_timeout; int ssl_connect_timeout; bool use_tls_sni; }; static constexpr struct EndpointParams ENDPOINT_PARAMS_DEFAULT = { .address_family = AF_UNSPEC, .max_connections = 200, .connect_timeout = 10 * 1000, .response_timeout = 10 * 1000, .ssl_connect_timeout = 10 * 1000, .use_tls_sni = false, }; #endif workflow-0.11.8/src/manager/RouteManager.cc000066400000000000000000000277631476003635400206030ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) */ #include #include #include #include #include #include #include #include #include #include #include #include #include #include "list.h" #include "rbtree.h" #include "WFGlobal.h" #include "CommScheduler.h" #include "EndpointParams.h" #include "RouteManager.h" #include "StringUtil.h" #define GET_CURRENT_SECOND std::chrono::duration_cast(std::chrono::steady_clock::now().time_since_epoch()).count() #define MTTR_SECOND 30 using RouteTargetTCP = RouteManager::RouteTarget; class RouteTargetUDP : public RouteManager::RouteTarget { private: virtual int create_connect_fd() { const struct sockaddr *addr; socklen_t addrlen; this->get_addr(&addr, &addrlen); return socket(addr->sa_family, SOCK_DGRAM, 0); } }; class RouteTargetSCTP : public RouteManager::RouteTarget { private: #ifdef IPPROTO_SCTP virtual int create_connect_fd() { const struct sockaddr *addr; socklen_t addrlen; this->get_addr(&addr, &addrlen); return socket(addr->sa_family, SOCK_STREAM, IPPROTO_SCTP); } #else virtual int create_connect_fd() { errno = EPROTONOSUPPORT; return -1; } #endif }; /* To support TLS SNI. */ class RouteTargetTCPSNI : public RouteTargetTCP { private: virtual int init_ssl(SSL *ssl) { if (SSL_set_tlsext_host_name(ssl, this->hostname.c_str()) > 0) return 0; else return -1; } private: std::string hostname; public: RouteTargetTCPSNI(const std::string& name) : hostname(name) { } }; class RouteTargetSCTPSNI : public RouteTargetSCTP { private: virtual int init_ssl(SSL *ssl) { if (SSL_set_tlsext_host_name(ssl, this->hostname.c_str()) > 0) return 0; else return -1; } private: std::string hostname; public: RouteTargetSCTPSNI(const std::string& name) : hostname(name) { } }; // protocol_name\n user\n pass\n dbname\n ai_addr ai_addrlen \n.... 
// struct RouteParams { enum TransportType transport_type; const struct addrinfo *addrinfo; uint64_t key; SSL_CTX *ssl_ctx; size_t max_connections; int connect_timeout; int response_timeout; int ssl_connect_timeout; bool use_tls_sni; const std::string& hostname; }; class RouteResultEntry { public: struct rb_node rb; CommSchedObject *request_object; CommSchedGroup *group; std::mutex mutex; std::vector targets; struct list_head breaker_list; uint64_t key; int nleft; int nbreak; RouteResultEntry(): request_object(NULL), group(NULL) { INIT_LIST_HEAD(&this->breaker_list); this->nleft = 0; this->nbreak = 0; } public: int init(const struct RouteParams *params); void deinit(); void notify_unavailable(RouteManager::RouteTarget *target); void notify_available(RouteManager::RouteTarget *target); void check_breaker(); private: void free_list(); RouteManager::RouteTarget *create_target(const struct RouteParams *params, const struct addrinfo *addrinfo); int add_group_targets(const struct RouteParams *params); }; struct __breaker_node { RouteManager::RouteTarget *target; int64_t timeout; struct list_head breaker_list; }; RouteManager::RouteTarget * RouteResultEntry::create_target(const struct RouteParams *params, const struct addrinfo *addr) { RouteManager::RouteTarget *target; switch (params->transport_type) { case TT_TCP_SSL: if (params->use_tls_sni) target = new RouteTargetTCPSNI(params->hostname); else case TT_TCP: target = new RouteTargetTCP(); break; case TT_UDP: target = new RouteTargetUDP(); break; case TT_SCTP_SSL: if (params->use_tls_sni) target = new RouteTargetSCTPSNI(params->hostname); else case TT_SCTP: target = new RouteTargetSCTP(); break; default: errno = EINVAL; return NULL; } if (target->init(addr->ai_addr, addr->ai_addrlen, params->ssl_ctx, params->connect_timeout, params->ssl_connect_timeout, params->response_timeout, params->max_connections) < 0) { delete target; target = NULL; } return target; } int RouteResultEntry::init(const struct RouteParams *params) { 
const struct addrinfo *addr = params->addrinfo; RouteManager::RouteTarget *target; if (addr == NULL)//0 { errno = EINVAL; return -1; } if (addr->ai_next == NULL)//1 { target = this->create_target(params, addr); if (target) { this->targets.push_back(target); this->request_object = target; this->key = params->key; return 0; } return -1; } this->group = new CommSchedGroup(); if (this->group->init() >= 0) { if (this->add_group_targets(params) >= 0) { this->request_object = this->group; this->key = params->key; return 0; } this->group->deinit(); } delete this->group; return -1; } int RouteResultEntry::add_group_targets(const struct RouteParams *params) { RouteManager::RouteTarget *target; const struct addrinfo *addr; for (addr = params->addrinfo; addr; addr = addr->ai_next) { target = this->create_target(params, addr); if (target) { if (this->group->add(target) >= 0) { this->targets.push_back(target); this->nleft++; continue; } target->deinit(); delete target; } for (auto *target : this->targets) { this->group->remove(target); target->deinit(); delete target; } return -1; } return 0; } void RouteResultEntry::deinit() { for (auto *target : this->targets) { if (this->group) this->group->remove(target); target->deinit(); delete target; } if (this->group) { this->group->deinit(); delete this->group; } struct list_head *pos, *tmp; __breaker_node *node; list_for_each_safe(pos, tmp, &this->breaker_list) { node = list_entry(pos, __breaker_node, breaker_list); list_del(pos); delete node; } } void RouteResultEntry::notify_unavailable(RouteManager::RouteTarget *target) { if (this->targets.size() <= 1) return; int errno_bak = errno; std::lock_guard lock(this->mutex); if (this->nleft <= 1) return; if (this->group->remove(target) < 0) { errno = errno_bak; return; } auto *node = new __breaker_node; node->target = target; node->timeout = GET_CURRENT_SECOND + MTTR_SECOND; list_add_tail(&node->breaker_list, &this->breaker_list); this->nbreak++; this->nleft--; } void 
RouteResultEntry::notify_available(RouteManager::RouteTarget *target) { if (this->targets.size() <= 1 || this->nbreak == 0) return; int errno_bak = errno; std::lock_guard lock(this->mutex); if (this->group->add(target) == 0) this->nleft++; else errno = errno_bak; } void RouteResultEntry::check_breaker() { if (this->targets.size() <= 1 || this->nbreak == 0) return; struct list_head *pos, *tmp; __breaker_node *node; int errno_bak = errno; int64_t cur_time = GET_CURRENT_SECOND; std::lock_guard lock(this->mutex); list_for_each_safe(pos, tmp, &this->breaker_list) { node = list_entry(pos, __breaker_node, breaker_list); if (cur_time >= node->timeout) { if (this->group->add(node->target) == 0) this->nleft++; else errno = errno_bak; list_del(pos); delete node; this->nbreak--; } } } static inline int __addr_cmp(const struct addrinfo *x, const struct addrinfo *y) { //todo ai_protocol if (x->ai_addrlen == y->ai_addrlen) return memcmp(x->ai_addr, y->ai_addr, x->ai_addrlen); else if (x->ai_addrlen < y->ai_addrlen) return -1; else return 1; } static inline bool __addr_less(const struct addrinfo *x, const struct addrinfo *y) { return __addr_cmp(x, y) < 0; } static uint64_t __fnv_hash(const unsigned char *data, size_t size) { uint64_t hash = 14695981039346656037ULL; while (size) { hash ^= (const uint64_t)*data++; hash *= 1099511628211ULL; size--; } return hash; } static uint64_t __generate_key(enum TransportType type, const struct addrinfo *addrinfo, const std::string& other_info, const struct EndpointParams *ep_params, const std::string& hostname, SSL_CTX *ssl_ctx) { const int params[] = { ep_params->address_family, (int)ep_params->max_connections, ep_params->connect_timeout, ep_params->response_timeout }; std::string buf((const char *)&type, sizeof (enum TransportType)); if (!other_info.empty()) buf += other_info; buf.append((const char *)params, sizeof params); if (type == TT_TCP_SSL || type == TT_SCTP_SSL) { buf.append((const char *)&ssl_ctx, sizeof (void *)); buf.append((const 
char *)&ep_params->ssl_connect_timeout, sizeof (int)); if (ep_params->use_tls_sni) { buf += hostname; buf += '\n'; } } if (addrinfo->ai_next) { std::vector sorted_addr; sorted_addr.push_back(addrinfo); addrinfo = addrinfo->ai_next; do { sorted_addr.push_back(addrinfo); addrinfo = addrinfo->ai_next; } while (addrinfo); std::sort(sorted_addr.begin(), sorted_addr.end(), __addr_less); for (const struct addrinfo *p : sorted_addr) { buf.append((const char *)&p->ai_addrlen, sizeof (socklen_t)); buf.append((const char *)p->ai_addr, p->ai_addrlen); } } else { buf.append((const char *)&addrinfo->ai_addrlen, sizeof (socklen_t)); buf.append((const char *)addrinfo->ai_addr, addrinfo->ai_addrlen); } return __fnv_hash((const unsigned char *)buf.c_str(), buf.size()); } RouteManager::~RouteManager() { RouteResultEntry *entry; while (cache_.rb_node) { entry = rb_entry(cache_.rb_node, RouteResultEntry, rb); rb_erase(cache_.rb_node, &cache_); entry->deinit(); delete entry; } } int RouteManager::get(enum TransportType type, const struct addrinfo *addrinfo, const std::string& other_info, const struct EndpointParams *ep_params, const std::string& hostname, SSL_CTX *ssl_ctx, RouteResult& result) { if (type == TT_TCP_SSL || type == TT_SCTP_SSL) { static SSL_CTX *global_client_ctx = WFGlobal::get_ssl_client_ctx(); if (ssl_ctx == NULL) ssl_ctx = global_client_ctx; } else ssl_ctx = NULL; uint64_t key = __generate_key(type, addrinfo, other_info, ep_params, hostname, ssl_ctx); struct rb_node **p = &cache_.rb_node; struct rb_node *parent = NULL; RouteResultEntry *bound = NULL; RouteResultEntry *entry; std::lock_guard lock(mutex_); while (*p) { parent = *p; entry = rb_entry(*p, RouteResultEntry, rb); if (key <= entry->key) { bound = entry; p = &(*p)->rb_left; } else p = &(*p)->rb_right; } if (bound && bound->key == key) { entry = bound; entry->check_breaker(); } else { struct RouteParams params = { .transport_type = type, .addrinfo = addrinfo, .key = key, .ssl_ctx = ssl_ctx, .max_connections = 
ep_params->max_connections, .connect_timeout = ep_params->connect_timeout, .response_timeout = ep_params->response_timeout, .ssl_connect_timeout = ep_params->ssl_connect_timeout, .use_tls_sni = ep_params->use_tls_sni, .hostname = hostname, }; entry = new RouteResultEntry; if (entry->init(¶ms) >= 0) { rb_link_node(&entry->rb, parent, p); rb_insert_color(&entry->rb, &cache_); } else { delete entry; return -1; } } result.cookie = entry; result.request_object = entry->request_object; return 0; } void RouteManager::notify_unavailable(void *cookie, CommTarget *target) { if (cookie && target) ((RouteResultEntry *)cookie)->notify_unavailable((RouteTarget *)target); } void RouteManager::notify_available(void *cookie, CommTarget *target) { if (cookie && target) ((RouteResultEntry *)cookie)->notify_available((RouteTarget *)target); } workflow-0.11.8/src/manager/RouteManager.h000066400000000000000000000046721476003635400204370ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) */ #ifndef _ROUTEMANAGER_H_ #define _ROUTEMANAGER_H_ #include #include #include #include #include #include #include "rbtree.h" #include "WFConnection.h" #include "EndpointParams.h" #include "CommScheduler.h" class RouteManager { public: class RouteResult { public: void *cookie; CommSchedObject *request_object; public: RouteResult(): cookie(NULL), request_object(NULL) { } void clear() { cookie = NULL; request_object = NULL; } }; class RouteTarget : public CommSchedTarget { #if OPENSSL_VERSION_NUMBER >= 0x10100000L public: int init(const struct sockaddr *addr, socklen_t addrlen, SSL_CTX *ssl_ctx, int connect_timeout, int ssl_connect_timeout, int response_timeout, size_t max_connections) { int ret = this->CommSchedTarget::init(addr, addrlen, ssl_ctx, connect_timeout, ssl_connect_timeout, response_timeout, max_connections); if (ret >= 0 && ssl_ctx) SSL_CTX_up_ref(ssl_ctx); return ret; } void deinit() { SSL_CTX *ssl_ctx = this->get_ssl_ctx(); this->CommSchedTarget::deinit(); if (ssl_ctx) SSL_CTX_free(ssl_ctx); } #endif public: int state; private: virtual WFConnection *new_connection(int connect_fd) { return new WFConnection; } public: RouteTarget() : state(0) { } }; public: int get(enum TransportType type, const struct addrinfo *addrinfo, const std::string& other_info, const struct EndpointParams *ep_params, const std::string& hostname, SSL_CTX *ssl_ctx, RouteResult& result); RouteManager() { cache_.rb_node = NULL; } ~RouteManager(); private: std::mutex mutex_; struct rb_root cache_; public: static void notify_unavailable(void *cookie, CommTarget *target); static void notify_available(void *cookie, CommTarget *target); }; #endif workflow-0.11.8/src/manager/UpstreamManager.cc000066400000000000000000000145021476003635400212700ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) */ #include #include #include "UpstreamManager.h" #include "WFNameService.h" #include "WFGlobal.h" #include "UpstreamPolicies.h" class __UpstreamManager { public: static __UpstreamManager *get_instance() { static __UpstreamManager kInstance; return &kInstance; } void add_upstream_policy(UPSGroupPolicy *policy) { pthread_mutex_lock(&this->mutex); this->upstream_policies.push_back(policy); pthread_mutex_unlock(&this->mutex); } private: __UpstreamManager() : mutex(PTHREAD_MUTEX_INITIALIZER) { } ~__UpstreamManager() { for (UPSGroupPolicy *policy : this->upstream_policies) delete policy; } pthread_mutex_t mutex; std::vector upstream_policies; }; int UpstreamManager::upstream_create_round_robin(const std::string& name, bool try_another) { WFNameService *ns = WFGlobal::get_name_service(); auto *policy = new UPSRoundRobinPolicy(try_another); if (ns->add_policy(name.c_str(), policy) >= 0) { __UpstreamManager::get_instance()->add_upstream_policy(policy); return 0; } delete policy; return -1; } static unsigned int __default_consistent_hash(const char *path, const char *query, const char *fragment) { static std::hash std_hash; std::string str(path); str += query; str += fragment; return std_hash(str); } int UpstreamManager::upstream_create_consistent_hash(const std::string& name, upstream_route_t consistent_hash) { WFNameService *ns = WFGlobal::get_name_service(); UPSConsistentHashPolicy *policy; policy = new UPSConsistentHashPolicy( consistent_hash ? 
std::move(consistent_hash) : __default_consistent_hash); if (ns->add_policy(name.c_str(), policy) >= 0) { __UpstreamManager::get_instance()->add_upstream_policy(policy); return 0; } delete policy; return -1; } int UpstreamManager::upstream_create_weighted_random(const std::string& name, bool try_another) { WFNameService *ns = WFGlobal::get_name_service(); auto *policy = new UPSWeightedRandomPolicy(try_another); if (ns->add_policy(name.c_str(), policy) >= 0) { __UpstreamManager::get_instance()->add_upstream_policy(policy); return 0; } delete policy; return -1; } int UpstreamManager::upstream_create_vnswrr(const std::string& name) { WFNameService *ns = WFGlobal::get_name_service(); auto *policy = new UPSVNSWRRPolicy(); if (ns->add_policy(name.c_str(), policy) >= 0) { __UpstreamManager::get_instance()->add_upstream_policy(policy); return 0; } delete policy; return -1; } int UpstreamManager::upstream_create_manual(const std::string& name, upstream_route_t select, bool try_another, upstream_route_t consistent_hash) { WFNameService *ns = WFGlobal::get_name_service(); UPSManualPolicy *policy; policy = new UPSManualPolicy(try_another, std::move(select), consistent_hash ? 
std::move(consistent_hash) : __default_consistent_hash); if (ns->add_policy(name.c_str(), policy) >= 0) { __UpstreamManager::get_instance()->add_upstream_policy(policy); return 0; } delete policy; return -1; } int UpstreamManager::upstream_add_server(const std::string& name, const std::string& address) { return UpstreamManager::upstream_add_server(name, address, &ADDRESS_PARAMS_DEFAULT); } int UpstreamManager::upstream_add_server(const std::string& name, const std::string& address, const AddressParams *address_params) { WFNameService *ns = WFGlobal::get_name_service(); UPSGroupPolicy *policy = dynamic_cast(ns->get_policy(name.c_str())); if (policy) { policy->add_server(address, address_params); return 0; } errno = ENOENT; return -1; } int UpstreamManager::upstream_remove_server(const std::string& name, const std::string& address) { WFNameService *ns = WFGlobal::get_name_service(); UPSGroupPolicy *policy = dynamic_cast(ns->get_policy(name.c_str())); if (policy) return policy->remove_server(address); errno = ENOENT; return -1; } int UpstreamManager::upstream_delete(const std::string& name) { WFNameService *ns = WFGlobal::get_name_service(); UPSGroupPolicy *policy = dynamic_cast(ns->del_policy(name.c_str())); if (policy) return 0; errno = ENOENT; return -1; } std::vector UpstreamManager::upstream_main_address_list(const std::string& name) { std::vector address; WFNameService *ns = WFGlobal::get_name_service(); UPSGroupPolicy *policy = dynamic_cast(ns->get_policy(name.c_str())); if (policy) policy->get_main_address(address); return address; } int UpstreamManager::upstream_disable_server(const std::string& name, const std::string& address) { WFNameService *ns = WFGlobal::get_name_service(); UPSGroupPolicy *policy = dynamic_cast(ns->get_policy(name.c_str())); if (policy) { policy->disable_server(address); return 0; } errno = ENOENT; return -1; } int UpstreamManager::upstream_enable_server(const std::string& name, const std::string& address) { WFNameService *ns = 
WFGlobal::get_name_service(); UPSGroupPolicy *policy = dynamic_cast(ns->get_policy(name.c_str())); if (policy) { policy->enable_server(address); return 0; } errno = ENOENT; return -1; } int UpstreamManager::upstream_replace_server(const std::string& name, const std::string& address, const struct AddressParams *address_params) { WFNameService *ns = WFGlobal::get_name_service(); UPSGroupPolicy *policy = dynamic_cast(ns->get_policy(name.c_str())); if (policy) { policy->replace_server(address, address_params); return 0; } errno = ENOENT; return -1; } workflow-0.11.8/src/manager/UpstreamManager.h000066400000000000000000000204571476003635400211400ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) */ #ifndef _UPSTREAM_MANAGER_H_ #define _UPSTREAM_MANAGER_H_ #include #include #include "WFServiceGovernance.h" #include "UpstreamPolicies.h" #include "WFGlobal.h" /** * @file UpstreamManager.h * @brief Local Reverse Proxy & Load Balance & Service Discovery * @details * - This is very similar with Nginx-Upstream. * - Do not cost any other network resource, We just simulate in local to choose one target properly. * - This is working only for the current process. */ /** * @brief Upstream Management Class * @details * - We support three modes: * 1. Weighted-Random * 2. Consistent-Hash * 3. Manual-Select * - Additional, we support Main-backup & Group for server and working well in any mode. 
* * @code{.cc} upstream_create_weighted_random("abc.sogou", true); //UPSTREAM_WEIGHTED_RANDOM upstream_add_server("abc.sogou", "192.168.2.100:8081"); //weight=1, max_fails=200 upstream_add_server("abc.sogou", "192.168.2.100:9090"); //weight=1, max_fails=200 AddressParams params = ADDRESS_PARAMS_DEFAULT; params.weight = 2; params.max_fails = 6; upstream_add_server("abc.sogou", "www.sogou.com", ¶ms); //weight=2, max_fails=6 //send request with url like http://abc.sogou/somepath/somerequest upstream_create_consistent_hash("def.sogou", [](const char *path, const char *query, const char *fragment) -> int { return somehash(...)); }); //UPSTREAM_CONSISTENT_HASH upstream_create_manual("xyz.sogou", [](const char *path, const char *query, const char *fragment) -> int { return select_xxx(...)); }, true, [](const char *path, const char *query, const char *fragment) -> int { return rehash(...)); },); //UPSTREAM_MANUAL * @endcode */ class UpstreamManager { public: /** * @brief MODE 0: round-robin select * @param[in] name upstream name * @param[in] try_another when first choice is failed, try another one or not * @return success/fail * @retval 0 success * @retval -1 fail, more info see errno * @note * when first choose server is already down: * - if try_another==false, request will be failed * - if try_another==true, upstream will choose the next */ static int upstream_create_round_robin(const std::string& name, bool try_another); /** * @brief MODE 1: consistent-hashing select * @param[in] name upstream name * @param[in] consitent_hash consistent-hash functional * @return success/fail * @retval 0 success * @retval -1 fail, more info see errno * @note consitent_hash need to return value in 0~(2^31-1) Balance/Monotonicity/Spread/Smoothness * @note if consitent_hash==nullptr, upstream will use std::hash with request uri */ static int upstream_create_consistent_hash(const std::string& name, upstream_route_t consitent_hash); /** * @brief MODE 2: weighted-random select * @param[in] 
name upstream name * @param[in] try_another when first choice is failed, try another one or not * @return success/fail * @retval 0 success * @retval -1 fail, more info see errno * @note * when first choose server is already down: * - if try_another==false, request will be failed * - if try_another==true, upstream will choose from alive-servers by weight-random strategy */ static int upstream_create_weighted_random(const std::string& name, bool try_another); /** * @brief MODE 3: manual select * @param[in] name upstream name * @param[in] select manual select functional, just tell us main-index. * @param[in] try_another when first choice is failed, try another one or not * @param[in] consitent_hash consistent-hash functional * @return success/fail * @retval 0 success * @retval -1 fail, more info see errno * @note * when first choose server is already down: * - if try_another==false, request will be failed, consistent_hash value will be ignored * - if try_another==true, upstream will work with consistent hash mode,if consitent_hash==NULL, upstream will use std::hash with request uri * @warning select functional cannot be nullptr! 
*/ static int upstream_create_manual(const std::string& name, upstream_route_t select, bool try_another, upstream_route_t consitent_hash); /** * @brief MODE 4: VNSWRR select * @param[in] name upstream name * @return success/fail * @retval 0 success * @retval -1 fail, more info see errno * @note */ static int upstream_create_vnswrr(const std::string& name); /** * @brief Delete one upstream * @param[in] name upstream name * @return success/fail * @retval 0 success * @retval -1 fail, not found */ static int upstream_delete(const std::string& name); public: /** * @brief Add server into one upstream, with default config * @param[in] name upstream name * @param[in] address ip OR host OR ip:port OR host:port OR /unix-domain-socket * @return success/fail * @retval 0 success * @retval -1 fail, more info see errno * @warning Same address add twice, means two different server */ static int upstream_add_server(const std::string& name, const std::string& address); /** * @brief Add server into one upstream, with custom config * @param[in] name upstream name * @param[in] address ip OR host OR ip:port OR host:port OR /unix-domain-socket * @param[in] address_params custom config for this target server * @return success/fail * @retval 0 success * @retval -1 fail, more info see errno * @warning Same address with different params, means two different server * @warning Same address with exactly same params, still means two different server */ static int upstream_add_server(const std::string& name, const std::string& address, const struct AddressParams *address_params); /** * @brief Remove server from one upstream * @param[in] name upstream name * @param[in] address same as address when add by upstream_add_server * @return success/fail * @retval >=0 success, the amount of being removed server * @retval -1 fail, upstream name not found * @warning If server servers has the same address in this upstream, we will remove them all */ static int upstream_remove_server(const std::string& name, 
const std::string& address); /** * @brief get all main servers address list from one upstream * @param[in] name upstream name * @return all main servers' address list * @warning If server servers has the same address in this upstream, then will appear in the vector multiply times */ static std::vector upstream_main_address_list(const std::string& name); public: /// @breif for plugin static int upstream_disable_server(const std::string& name, const std::string& address); static int upstream_enable_server(const std::string& name, const std::string& address); static int upstream_replace_server(const std::string& name, const std::string& address, const struct AddressParams *address_params); }; #endif workflow-0.11.8/src/manager/WFFacilities.h000066400000000000000000000051001476003635400203420ustar00rootroot00000000000000/* Copyright (c) 2020 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Authors: Li Yingxin (liyingxin@sogou-inc.com) Wu Jiaxu (wujiaxu@sogou-inc.com) Xie Han (xiehan@sogou-inc.com) */ #ifndef _WFFACILITIES_H_ #define _WFFACILITIES_H_ #include #include "WFFuture.h" #include "WFTaskFactory.h" class WFFacilities { public: static void usleep(unsigned int microseconds); static WFFuture async_usleep(unsigned int microseconds); public: template static void go(const std::string& queue_name, FUNC&& func, ARGS&&... 
args); public: template struct WFNetworkResult { RESP resp; long long seqid; int task_state; int task_error; }; template static WFNetworkResult request(enum TransportType type, const std::string& url, REQ&& req, int retry_max); template static WFFuture> async_request(enum TransportType type, const std::string& url, REQ&& req, int retry_max); public:// async fileIO static WFFuture async_pread(int fd, void *buf, size_t count, off_t offset); static WFFuture async_pwrite(int fd, const void *buf, size_t count, off_t offset); static WFFuture async_preadv(int fd, const struct iovec *iov, int iovcnt, off_t offset); static WFFuture async_pwritev(int fd, const struct iovec *iov, int iovcnt, off_t offset); static WFFuture async_fsync(int fd); static WFFuture async_fdatasync(int fd); public: class WaitGroup { public: WaitGroup(int n); ~WaitGroup(); void done(); void wait() const; std::future_status wait(int timeout) const; private: static void __wait_group_callback(WFCounterTask *task); std::atomic nleft; WFCounterTask *task; WFFuture future; }; private: static void __timer_future_callback(WFTimerTask *task); static void __fio_future_callback(WFFileIOTask *task); static void __fvio_future_callback(WFFileVIOTask *task); static void __fsync_future_callback(WFFileSyncTask *task); }; #include "WFFacilities.inl" #endif workflow-0.11.8/src/manager/WFFacilities.inl000066400000000000000000000135661476003635400207140ustar00rootroot00000000000000/* Copyright (c) 2020 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Authors: Li Yingxin (liyingxin@sogou-inc.com) Wu Jiaxu (wujiaxu@sogou-inc.com) */ inline void WFFacilities::usleep(unsigned int microseconds) { async_usleep(microseconds).get(); } inline WFFuture WFFacilities::async_usleep(unsigned int microseconds) { auto *pr = new WFPromise(); auto fr = pr->get_future(); auto *task = WFTaskFactory::create_timer_task(microseconds, __timer_future_callback); task->user_data = pr; task->start(); return fr; } template void WFFacilities::go(const std::string& queue_name, FUNC&& func, ARGS&&... args) { WFTaskFactory::create_go_task(queue_name, std::forward(func), std::forward(args)...)->start(); } template WFFacilities::WFNetworkResult WFFacilities::request(enum TransportType type, const std::string& url, REQ&& req, int retry_max) { return async_request(type, url, std::forward(req), retry_max).get(); } template WFFuture> WFFacilities::async_request(enum TransportType type, const std::string& url, REQ&& req, int retry_max) { ParsedURI uri; auto *pr = new WFPromise>(); auto fr = pr->get_future(); auto *task = new WFComplexClientTask(retry_max, [pr](WFNetworkTask *task) { WFNetworkResult res; res.seqid = task->get_task_seq(); res.task_state = task->get_state(); res.task_error = task->get_error(); if (res.task_state == WFT_STATE_SUCCESS) res.resp = std::move(*task->get_resp()); pr->set_value(std::move(res)); delete pr; }); URIParser::parse(url, uri); task->init(std::move(uri)); task->set_transport_type(type); *task->get_req() = std::forward(req); task->start(); return fr; } inline WFFuture WFFacilities::async_pread(int fd, void *buf, size_t count, off_t offset) { auto *pr = new WFPromise(); auto fr = pr->get_future(); auto *task = WFTaskFactory::create_pread_task(fd, buf, count, offset, __fio_future_callback); task->user_data = pr; task->start(); return fr; } inline WFFuture WFFacilities::async_pwrite(int fd, const void *buf, size_t count, off_t offset) { auto *pr = new WFPromise(); auto fr = pr->get_future(); auto *task = 
WFTaskFactory::create_pwrite_task(fd, buf, count, offset, __fio_future_callback); task->user_data = pr; task->start(); return fr; } inline WFFuture WFFacilities::async_preadv(int fd, const struct iovec *iov, int iovcnt, off_t offset) { auto *pr = new WFPromise(); auto fr = pr->get_future(); auto *task = WFTaskFactory::create_preadv_task(fd, iov, iovcnt, offset, __fvio_future_callback); task->user_data = pr; task->start(); return fr; } inline WFFuture WFFacilities::async_pwritev(int fd, const struct iovec *iov, int iovcnt, off_t offset) { auto *pr = new WFPromise(); auto fr = pr->get_future(); auto *task = WFTaskFactory::create_pwritev_task(fd, iov, iovcnt, offset, __fvio_future_callback); task->user_data = pr; task->start(); return fr; } inline WFFuture WFFacilities::async_fsync(int fd) { auto *pr = new WFPromise(); auto fr = pr->get_future(); auto *task = WFTaskFactory::create_fsync_task(fd, __fsync_future_callback); task->user_data = pr; task->start(); return fr; } inline WFFuture WFFacilities::async_fdatasync(int fd) { auto *pr = new WFPromise(); auto fr = pr->get_future(); auto *task = WFTaskFactory::create_fdsync_task(fd, __fsync_future_callback); task->user_data = pr; task->start(); return fr; } inline void WFFacilities::__timer_future_callback(WFTimerTask *task) { auto *pr = static_cast *>(task->user_data); pr->set_value(); delete pr; } inline void WFFacilities::__fio_future_callback(WFFileIOTask *task) { auto *pr = static_cast *>(task->user_data); pr->set_value(task->get_retval()); delete pr; } inline void WFFacilities::__fvio_future_callback(WFFileVIOTask *task) { auto *pr = static_cast *>(task->user_data); pr->set_value(task->get_retval()); delete pr; } inline void WFFacilities::__fsync_future_callback(WFFileSyncTask *task) { auto *pr = static_cast *>(task->user_data); pr->set_value(task->get_retval()); delete pr; } inline WFFacilities::WaitGroup::WaitGroup(int n) : nleft(n) { if (n <= 0) { this->nleft = -1; return; } auto *pr = new WFPromise(); 
this->task = WFTaskFactory::create_counter_task(1, __wait_group_callback); this->future = pr->get_future(); this->task->user_data = pr; this->task->start(); } inline WFFacilities::WaitGroup::~WaitGroup() { if (this->nleft > 0) this->task->count(); } inline void WFFacilities::WaitGroup::done() { if (--this->nleft == 0) { this->task->count(); } } inline void WFFacilities::WaitGroup::wait() const { if (this->nleft < 0) return; this->future.wait(); } inline std::future_status WFFacilities::WaitGroup::wait(int timeout) const { if (this->nleft < 0) return std::future_status::ready; if (timeout < 0) { this->future.wait(); return std::future_status::ready; } return this->future.wait_for(std::chrono::milliseconds(timeout)); } inline void WFFacilities::WaitGroup::__wait_group_callback(WFCounterTask *task) { auto *pr = static_cast *>(task->user_data); pr->set_value(); delete pr; } workflow-0.11.8/src/manager/WFFuture.h000066400000000000000000000076711476003635400175570ustar00rootroot00000000000000/* Copyright (c) 2020 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Authors: Li Yingxin (liyingxin@sogou-inc.com) */ #ifndef _WFFUTURE_H_ #define _WFFUTURE_H_ #include #include #include #include "CommScheduler.h" #include "WFGlobal.h" template class WFFuture { public: WFFuture(std::future&& fr) : future(std::move(fr)) { } WFFuture() = default; WFFuture(const WFFuture&) = delete; WFFuture(WFFuture&& move) = default; WFFuture& operator=(const WFFuture&) = delete; WFFuture& operator=(WFFuture&& move) = default; void wait() const; template std::future_status wait_for(const std::chrono::duration& time_duration) const; template std::future_status wait_until(const std::chrono::time_point& timeout_time) const; RES get() { this->wait(); return this->future.get(); } bool valid() const { return this->future.valid(); } private: std::future future; }; template class WFPromise { public: WFPromise() = default; WFPromise(const WFPromise& promise) = delete; WFPromise(WFPromise&& move) = default; WFPromise& operator=(const WFPromise& promise) = delete; WFPromise& operator=(WFPromise&& move) = default; WFFuture get_future() { return WFFuture(this->promise.get_future()); } void set_value(const RES& value) { this->promise.set_value(value); } void set_value(RES&& value) { this->promise.set_value(std::move(value)); } private: std::promise promise; }; template void WFFuture::wait() const { if (this->future.wait_for(std::chrono::seconds(0)) != std::future_status::ready) { int cookie = WFGlobal::sync_operation_begin(); this->future.wait(); WFGlobal::sync_operation_end(cookie); } } template template std::future_status WFFuture::wait_for(const std::chrono::duration& time_duration) const { std::future_status status = std::future_status::ready; if (this->future.wait_for(std::chrono::seconds(0)) != std::future_status::ready) { int cookie = WFGlobal::sync_operation_begin(); status = this->future.wait_for(time_duration); WFGlobal::sync_operation_end(cookie); } return status; } template template std::future_status WFFuture::wait_until(const std::chrono::time_point& 
timeout_time) const { std::future_status status = std::future_status::ready; if (this->future.wait_for(std::chrono::seconds(0)) != std::future_status::ready) { int cookie = WFGlobal::sync_operation_begin(); status = this->future.wait_until(timeout_time); WFGlobal::sync_operation_end(cookie); } return status; } ///// WFFuture template specialization template<> inline void WFFuture::get() { this->wait(); this->future.get(); } template<> class WFPromise { public: WFPromise() = default; WFPromise(const WFPromise& promise) = delete; WFPromise(WFPromise&& move) = default; WFPromise& operator=(const WFPromise& promise) = delete; WFPromise& operator=(WFPromise&& move) = default; WFFuture get_future() { return WFFuture(this->promise.get_future()); } void set_value() { this->promise.set_value(); } // void set_value(const RES& value) { this->promise.set_value(value); } // void set_value(RES&& value) { this->promise.set_value(std::move(value)); } private: std::promise promise; }; #endif workflow-0.11.8/src/manager/WFGlobal.cc000066400000000000000000000476351476003635400176470ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) Xie Han (xiehan@sogou-inc.com) Liu Kai (liukaidx@sogou-inc.com) */ #include #include #include #include #include #include #include #include #include #include #include #include #if OPENSSL_VERSION_NUMBER < 0x10100000L # include # include # include # include #endif #include "CommScheduler.h" #include "Executor.h" #include "WFResourcePool.h" #include "WFTaskError.h" #include "WFDnsClient.h" #include "WFGlobal.h" #include "URIParser.h" class __WFGlobal { public: static __WFGlobal *get_instance() { static __WFGlobal kInstance; return &kInstance; } const char *get_default_port(const std::string& scheme) { const auto it = static_scheme_port_.find(scheme); if (it != static_scheme_port_.end()) return it->second; const char *port = NULL; user_scheme_port_mutex_.lock(); const auto it2 = user_scheme_port_.find(scheme); if (it2 != user_scheme_port_.end()) port = it2->second.c_str(); user_scheme_port_mutex_.unlock(); return port; } void register_scheme_port(const std::string& scheme, unsigned short port) { user_scheme_port_mutex_.lock(); user_scheme_port_[scheme] = std::to_string(port); user_scheme_port_mutex_.unlock(); } void sync_operation_begin() { bool inc; sync_mutex_.lock(); inc = ++sync_count_ > sync_max_; if (inc) sync_max_ = sync_count_; sync_mutex_.unlock(); if (inc) WFGlobal::increase_handler_thread(); } void sync_operation_end() { int dec = 0; sync_mutex_.lock(); if (--sync_count_ < (sync_max_ + 1) / 2) { dec = sync_max_ - 2 * sync_count_; sync_max_ -= dec; } sync_mutex_.unlock(); while (dec > 0) { WFGlobal::decrease_handler_thread(); dec--; } } private: __WFGlobal(); private: std::unordered_map static_scheme_port_; std::unordered_map user_scheme_port_; std::mutex user_scheme_port_mutex_; std::mutex sync_mutex_; int sync_count_; int sync_max_; }; __WFGlobal::__WFGlobal() { static_scheme_port_["dns"] = "53"; static_scheme_port_["Dns"] = "53"; static_scheme_port_["DNS"] = "53"; static_scheme_port_["dnss"] = "853"; 
static_scheme_port_["Dnss"] = "853"; static_scheme_port_["DNSs"] = "853"; static_scheme_port_["DNSS"] = "853"; static_scheme_port_["http"] = "80"; static_scheme_port_["Http"] = "80"; static_scheme_port_["HTTP"] = "80"; static_scheme_port_["https"] = "443"; static_scheme_port_["Https"] = "443"; static_scheme_port_["HTTPs"] = "443"; static_scheme_port_["HTTPS"] = "443"; static_scheme_port_["redis"] = "6379"; static_scheme_port_["Redis"] = "6379"; static_scheme_port_["REDIS"] = "6379"; static_scheme_port_["rediss"] = "6379"; static_scheme_port_["Rediss"] = "6379"; static_scheme_port_["REDISs"] = "6379"; static_scheme_port_["REDISS"] = "6379"; static_scheme_port_["mysql"] = "3306"; static_scheme_port_["Mysql"] = "3306"; static_scheme_port_["MySql"] = "3306"; static_scheme_port_["MySQL"] = "3306"; static_scheme_port_["MYSQL"] = "3306"; static_scheme_port_["mysqls"] = "3306"; static_scheme_port_["Mysqls"] = "3306"; static_scheme_port_["MySqls"] = "3306"; static_scheme_port_["MySQLs"] = "3306"; static_scheme_port_["MYSQLs"] = "3306"; static_scheme_port_["MYSQLS"] = "3306"; static_scheme_port_["kafka"] = "9092"; static_scheme_port_["Kafka"] = "9092"; static_scheme_port_["KAFKA"] = "9092"; static_scheme_port_["kafkas"] = "9093"; static_scheme_port_["Kafkas"] = "9093"; static_scheme_port_["KAFKAs"] = "9093"; static_scheme_port_["KAFKAS"] = "9093"; sync_count_ = 0; sync_max_ = 0; } #if OPENSSL_VERSION_NUMBER < 0x10100000L static std::mutex *__ssl_mutex; static void ssl_locking_callback(int mode, int type, const char* file, int line) { if (mode & CRYPTO_LOCK) __ssl_mutex[type].lock(); else if (mode & CRYPTO_UNLOCK) __ssl_mutex[type].unlock(); } #endif class __SSLManager { public: static __SSLManager *get_instance() { static __SSLManager kInstance; return &kInstance; } SSL_CTX *get_ssl_client_ctx() { return ssl_client_ctx_; } SSL_CTX *new_ssl_server_ctx() { return SSL_CTX_new(SSLv23_server_method()); } private: __SSLManager() { #if OPENSSL_VERSION_NUMBER < 0x10100000L 
__ssl_mutex = new std::mutex[CRYPTO_num_locks()]; CRYPTO_set_locking_callback(ssl_locking_callback); SSL_library_init(); SSL_load_error_strings(); //ERR_load_crypto_strings(); //OpenSSL_add_all_algorithms(); #endif ssl_client_ctx_ = SSL_CTX_new(SSLv23_client_method()); if (ssl_client_ctx_ == NULL) abort(); } ~__SSLManager() { SSL_CTX_free(ssl_client_ctx_); #if OPENSSL_VERSION_NUMBER < 0x10100000L //free ssl to avoid memory leak FIPS_mode_set(0); CRYPTO_set_locking_callback(NULL); # ifdef CRYPTO_LOCK_ECDH CRYPTO_THREADID_set_callback(NULL); # else CRYPTO_set_id_callback(NULL); # endif ENGINE_cleanup(); CONF_modules_unload(1); ERR_free_strings(); EVP_cleanup(); # ifdef CRYPTO_LOCK_ECDH ERR_remove_thread_state(NULL); # else ERR_remove_state(0); # endif CRYPTO_cleanup_all_ex_data(); sk_SSL_COMP_free(SSL_COMP_get_compression_methods()); delete []__ssl_mutex; #endif } private: SSL_CTX *ssl_client_ctx_; }; class __FileIOService : public IOService { public: __FileIOService(CommScheduler *scheduler): scheduler_(scheduler), flag_(true) {} int bind() { mutex_.lock(); flag_ = false; int ret = scheduler_->io_bind(this); if (ret < 0) flag_ = true; mutex_.unlock(); return ret; } void deinit() { std::unique_lock lock(mutex_); while (!flag_) cond_.wait(lock); lock.unlock(); IOService::deinit(); } private: virtual void handle_unbound() { mutex_.lock(); flag_ = true; cond_.notify_one(); mutex_.unlock(); } virtual void handle_stop(int error) { scheduler_->io_unbind(this); } CommScheduler *scheduler_; std::mutex mutex_; std::condition_variable cond_; bool flag_; }; class __ThreadDnsManager { public: static __ThreadDnsManager *get_instance() { static __ThreadDnsManager kInstance; return &kInstance; } ExecQueue *get_dns_queue() { return &dns_queue_; } Executor *get_dns_executor() { return &dns_executor_; } __ThreadDnsManager() { int ret; ret = dns_queue_.init(); if (ret < 0) abort(); ret = dns_executor_.init(WFGlobal::get_global_settings()->dns_threads); if (ret < 0) abort(); } 
~__ThreadDnsManager() { dns_executor_.deinit(); dns_queue_.deinit(); } private: ExecQueue dns_queue_; Executor dns_executor_; }; class __CommManager { public: static __CommManager *get_instance() { static __CommManager kInstance; __CommManager::created_ = true; return &kInstance; } CommScheduler *get_scheduler() { return &scheduler_; } IOService *get_io_service(); static bool is_created() { return created_; } private: __CommManager(): fio_service_(NULL), fio_flag_(false) { const auto *settings = WFGlobal::get_global_settings(); if (scheduler_.init(settings->poller_threads, settings->handler_threads) < 0) abort(); signal(SIGPIPE, SIG_IGN); } ~__CommManager() { // scheduler_.deinit() will triger fio_service to stop scheduler_.deinit(); if (fio_service_) { fio_service_->deinit(); delete fio_service_; } } private: CommScheduler scheduler_; __FileIOService *fio_service_; volatile bool fio_flag_; std::mutex fio_mutex_; private: static bool created_; }; bool __CommManager::created_ = false; inline IOService *__CommManager::get_io_service() { if (!fio_flag_) { fio_mutex_.lock(); if (!fio_flag_) { int maxevents = WFGlobal::get_global_settings()->fio_max_events; int n = 65536; fio_service_ = new __FileIOService(&scheduler_); while (fio_service_->init(maxevents) < 0) { if ((errno != EAGAIN && errno != EINVAL) || maxevents <= 16) abort(); while (n >= maxevents) n /= 2; maxevents = n; } if (fio_service_->bind() < 0) abort(); fio_flag_ = true; } fio_mutex_.unlock(); } return fio_service_; } class __ExecManager { protected: using ExecQueueMap = std::unordered_map; public: static __ExecManager *get_instance() { static __ExecManager kInstance; return &kInstance; } ExecQueue *get_exec_queue(const std::string& queue_name); Executor *get_compute_executor() { return &compute_executor_; } private: __ExecManager(): rwlock_(PTHREAD_RWLOCK_INITIALIZER) { int compute_threads = WFGlobal::get_global_settings()->compute_threads; if (compute_threads < 0) compute_threads = 
sysconf(_SC_NPROCESSORS_ONLN); if (compute_executor_.init(compute_threads) < 0) abort(); } ~__ExecManager() { compute_executor_.deinit(); for (auto& kv : queue_map_) { kv.second->deinit(); delete kv.second; } } private: pthread_rwlock_t rwlock_; ExecQueueMap queue_map_; Executor compute_executor_; };

/*
 * Return the ExecQueue registered under queue_name, creating it on first use.
 *
 * The lookup is first done under the read lock. On a miss the write lock is
 * taken and the map is searched again, because another thread may have
 * inserted the queue between the two locks.
 *
 * Returns NULL when a newly created queue fails to init().
 */
inline ExecQueue *__ExecManager::get_exec_queue(const std::string& queue_name)
{
	ExecQueue *queue = NULL;
	ExecQueueMap::const_iterator iter;

	pthread_rwlock_rdlock(&rwlock_);
	iter = queue_map_.find(queue_name);
	if (iter != queue_map_.cend())
		queue = iter->second;

	pthread_rwlock_unlock(&rwlock_);
	if (queue)
		return queue;

	pthread_rwlock_wrlock(&rwlock_);
	iter = queue_map_.find(queue_name);
	if (iter == queue_map_.cend())
	{
		queue = new ExecQueue();
		if (queue->init() >= 0)
			queue_map_.emplace(queue_name, queue);
		else
		{
			delete queue;
			queue = NULL;
		}
	}
	else
		queue = iter->second;

	pthread_rwlock_unlock(&rwlock_);
	return queue;
}

/*
 * Turn one "nameserver" token into a dns URL.
 *
 * - a token already starting with "dns://" or "dnss://" (case-insensitive)
 *   is used as is;
 * - an IPv6 literal is wrapped as "dns://[addr]";
 * - anything else becomes "dns://" + token.
 *
 * The URL is returned only if it parses with a non-empty host and
 * getaddrinfo(host, "53", hints) succeeds; otherwise "" is returned.
 */
static std::string __dns_server_url(const std::string& url,
									const struct addrinfo *hints)
{
	std::string host;
	ParsedURI uri;
	struct addrinfo *res;
	struct in6_addr buf;

	if (strncasecmp(url.c_str(), "dns://", 6) == 0 ||
		strncasecmp(url.c_str(), "dnss://", 7) == 0)
	{
		host = url;
	}
	else if (inet_pton(AF_INET6, url.c_str(), &buf) > 0)
		host = "dns://[" + url + "]";
	else
		host = "dns://" + url;

	if (URIParser::parse(host, uri) == 0 && uri.host && uri.host[0])
	{
		if (getaddrinfo(uri.host, "53", hints, &res) == 0)
		{
			freeaddrinfo(res);
			return host;
		}
	}

	return "";
}

/*
 * Split the text after a resolv.conf keyword into whitespace-separated tokens
 * and append them, comma-separated, to result.
 *
 * p must point at the character right after the keyword. If that character is
 * not whitespace (e.g. "nameserverX") the line is ignored. Scanning stops at
 * end of string or at a token that begins with '#' or ';'.
 * When is_nameserver is true each token is passed through __dns_server_url()
 * and dropped if that returns an empty string.
 */
static void __split_merge_str(const char *p, bool is_nameserver,
							  const struct addrinfo *hints,
							  std::string& result)
{
	const char *start;

	// isspace() is undefined for negative values other than EOF, and plain
	// char may be signed: always go through unsigned char.
	if (!isspace((unsigned char)*p))
		return;

	while (1)
	{
		while (isspace((unsigned char)*p))
			p++;

		start = p;
		while (*p && *p != '#' && *p != ';' && !isspace((unsigned char)*p))
			p++;

		if (start == p)
			break;

		std::string str(start, p);
		if (is_nameserver)
			str = __dns_server_url(str, hints);

		if (!str.empty())
		{
			if (!result.empty())
				result.push_back(',');

			result.append(str);
		}
	}
}

static inline const char *__try_options(const
char *p, const char *q, const char *r) { size_t len = strlen(r); if ((size_t)(q - p) >= len && strncmp(p, r, len) == 0) return p + len; return NULL; } static void __set_options(const char *p, int *ndots, int *attempts, bool *rotate) { const char *start; const char *opt; if (!isspace(*p)) return; while (1) { while (isspace(*p)) p++; start = p; while (*p && *p != '#' && *p != ';' && !isspace(*p)) p++; if (start == p) break; if ((opt = __try_options(start, p, "ndots:")) != NULL) *ndots = atoi(opt); else if ((opt = __try_options(start, p, "attempts:")) != NULL) *attempts = atoi(opt); else if ((opt = __try_options(start, p, "rotate")) != NULL) *rotate = true; } } static int __parse_resolv_conf(const char *path, std::string& url, std::string& search_list, int *ndots, int *attempts, bool *rotate) { size_t bufsize = 0; char *line = NULL; FILE *fp; int ret; fp = fopen(path, "r"); if (!fp) return -1; const struct WFGlobalSettings *settings = WFGlobal::get_global_settings(); struct addrinfo hints = { .ai_flags = AI_ADDRCONFIG | AI_NUMERICHOST | AI_NUMERICSERV, .ai_family = settings->dns_server_params.address_family, .ai_socktype = SOCK_STREAM, }; while ((ret = getline(&line, &bufsize, fp)) > 0) { if (strncmp(line, "nameserver", 10) == 0) __split_merge_str(line + 10, true, &hints, url); else if (strncmp(line, "search", 6) == 0) __split_merge_str(line + 6, false, &hints, search_list); else if (strncmp(line, "options", 7) == 0) __set_options(line + 7, ndots, attempts, rotate); } ret = ferror(fp) ? 
-1 : 0; free(line); fclose(fp); return ret; } class __DnsClientManager { public: static __DnsClientManager *get_instance() { static __DnsClientManager kInstance; return &kInstance; } public: WFDnsClient *get_dns_client() { return client_; } WFResourcePool *get_dns_respool() { return &respool_; }; private: __DnsClientManager() : respool_(WFGlobal::get_global_settings()-> dns_server_params.max_connections) { const char *path = WFGlobal::get_global_settings()->resolv_conf_path; client_ = NULL; if (path && path[0]) { int ndots = 1; int attempts = 2; bool rotate = false; std::string url; std::string search; __parse_resolv_conf(path, url, search, &ndots, &attempts, &rotate); if (url.size() == 0) url = "8.8.8.8"; client_ = new WFDnsClient; if (client_->init(url, search, ndots, attempts, rotate) >= 0) return; delete client_; client_ = NULL; } } ~__DnsClientManager() { if (client_) { client_->deinit(); delete client_; } } WFDnsClient *client_; WFResourcePool respool_; }; struct WFGlobalSettings WFGlobal::settings_ = GLOBAL_SETTINGS_DEFAULT; RouteManager WFGlobal::route_manager_; DnsCache WFGlobal::dns_cache_; WFDnsResolver WFGlobal::dns_resolver_; WFNameService WFGlobal::name_service_(&WFGlobal::dns_resolver_); bool WFGlobal::is_scheduler_created() { return __CommManager::is_created(); } CommScheduler *WFGlobal::get_scheduler() { return __CommManager::get_instance()->get_scheduler(); } SSL_CTX *WFGlobal::get_ssl_client_ctx() { return __SSLManager::get_instance()->get_ssl_client_ctx(); } SSL_CTX *WFGlobal::new_ssl_server_ctx() { return __SSLManager::get_instance()->new_ssl_server_ctx(); } ExecQueue *WFGlobal::get_exec_queue(const std::string& queue_name) { return __ExecManager::get_instance()->get_exec_queue(queue_name); } Executor *WFGlobal::get_compute_executor() { return __ExecManager::get_instance()->get_compute_executor(); } IOService *WFGlobal::get_io_service() { return __CommManager::get_instance()->get_io_service(); } ExecQueue *WFGlobal::get_dns_queue() { return 
__ThreadDnsManager::get_instance()->get_dns_queue(); } Executor *WFGlobal::get_dns_executor() { return __ThreadDnsManager::get_instance()->get_dns_executor(); } WFDnsClient *WFGlobal::get_dns_client() { return __DnsClientManager::get_instance()->get_dns_client(); } WFResourcePool *WFGlobal::get_dns_respool() { return __DnsClientManager::get_instance()->get_dns_respool(); } const char *WFGlobal::get_default_port(const std::string& scheme) { return __WFGlobal::get_instance()->get_default_port(scheme); } void WFGlobal::register_scheme_port(const std::string& scheme, unsigned short port) { __WFGlobal::get_instance()->register_scheme_port(scheme, port); } int WFGlobal::sync_operation_begin() { if (WFGlobal::is_scheduler_created() && WFGlobal::get_scheduler()->is_handler_thread()) { __WFGlobal::get_instance()->sync_operation_begin(); return 1; } return 0; } void WFGlobal::sync_operation_end(int cookie) { if (cookie) __WFGlobal::get_instance()->sync_operation_end(); } static inline const char *__get_ssl_error_string(int error) { switch (error) { case SSL_ERROR_NONE: return "SSL Error None"; case SSL_ERROR_ZERO_RETURN: return "SSL Error Zero Return"; case SSL_ERROR_WANT_READ: return "SSL Error Want Read"; case SSL_ERROR_WANT_WRITE: return "SSL Error Want Write"; case SSL_ERROR_WANT_CONNECT: return "SSL Error Want Connect"; case SSL_ERROR_WANT_ACCEPT: return "SSL Error Want Accept"; case SSL_ERROR_WANT_X509_LOOKUP: return "SSL Error Want X509 Lookup"; #ifdef SSL_ERROR_WANT_ASYNC case SSL_ERROR_WANT_ASYNC: return "SSL Error Want Async"; #endif #ifdef SSL_ERROR_WANT_ASYNC_JOB case SSL_ERROR_WANT_ASYNC_JOB: return "SSL Error Want Async Job"; #endif #ifdef SSL_ERROR_WANT_CLIENT_HELLO_CB case SSL_ERROR_WANT_CLIENT_HELLO_CB: return "SSL Error Want Client Hello CB"; #endif case SSL_ERROR_SYSCALL: return "SSL System Error"; case SSL_ERROR_SSL: return "SSL Error SSL"; default: break; } return "Unknown"; } static inline const char *__get_task_error_string(int error) { switch (error) 
{ case WFT_ERR_URI_PARSE_FAILED: return "URI Parse Failed"; case WFT_ERR_URI_SCHEME_INVALID: return "URI Scheme Invalid"; case WFT_ERR_URI_PORT_INVALID: return "URI Port Invalid"; case WFT_ERR_UPSTREAM_UNAVAILABLE: return "Upstream Unavailable"; case WFT_ERR_HTTP_BAD_REDIRECT_HEADER: return "Http Bad Redirect Header"; case WFT_ERR_HTTP_PROXY_CONNECT_FAILED: return "Http Proxy Connect Failed"; case WFT_ERR_REDIS_ACCESS_DENIED: return "Redis Access Denied"; case WFT_ERR_REDIS_COMMAND_DISALLOWED: return "Redis Command Disallowed"; case WFT_ERR_MYSQL_HOST_NOT_ALLOWED: return "MySQL Host Not Allowed"; case WFT_ERR_MYSQL_ACCESS_DENIED: return "MySQL Access Denied"; case WFT_ERR_MYSQL_INVALID_CHARACTER_SET: return "MySQL Invalid Character Set"; case WFT_ERR_MYSQL_COMMAND_DISALLOWED: return "MySQL Command Disallowed"; case WFT_ERR_MYSQL_QUERY_NOT_SET: return "MySQL Query Not Set"; case WFT_ERR_MYSQL_SSL_NOT_SUPPORTED: return "MySQL SSL Not Supported"; case WFT_ERR_KAFKA_PARSE_RESPONSE_FAILED: return "Kafka parse response failed"; case WFT_ERR_KAFKA_PRODUCE_FAILED: return "Kafka produce api failed"; case WFT_ERR_KAFKA_FETCH_FAILED: return "Kafka fetch api failed"; case WFT_ERR_KAFKA_CGROUP_FAILED: return "Kafka cgroup failed"; case WFT_ERR_KAFKA_COMMIT_FAILED: return "Kafka commit api failed"; case WFT_ERR_KAFKA_META_FAILED: return "Kafka meta api failed"; case WFT_ERR_KAFKA_LEAVEGROUP_FAILED: return "Kafka leavegroup failed"; case WFT_ERR_KAFKA_API_UNKNOWN: return "Kafka api type unknown"; case WFT_ERR_KAFKA_VERSION_DISALLOWED: return "Kafka broker version not supported"; case WFT_ERR_KAFKA_SASL_DISALLOWED: return "Kafka sasl disallowed"; case WFT_ERR_KAFKA_ARRANGE_FAILED: return "Kafka arrange failed"; case WFT_ERR_KAFKA_LIST_OFFSETS_FAILED: return "Kafka list offsets failed"; case WFT_ERR_KAFKA_CGROUP_ASSIGN_FAILED: return "Kafka cgroup assign failed"; case WFT_ERR_CONSUL_API_UNKNOWN: return "Consul api type unknown"; case WFT_ERR_CONSUL_CHECK_RESPONSE_FAILED: return 
"Consul check response failed"; default: break; } return "Unknown"; } const char *WFGlobal::get_error_string(int state, int error) { switch (state) { case WFT_STATE_SUCCESS: return "Success"; case WFT_STATE_TOREPLY: return "To Reply"; case WFT_STATE_NOREPLY: return "No Reply"; case WFT_STATE_SYS_ERROR: return strerror(error); case WFT_STATE_SSL_ERROR: return __get_ssl_error_string(error); case WFT_STATE_DNS_ERROR: return gai_strerror(error); case WFT_STATE_TASK_ERROR: return __get_task_error_string(error); case WFT_STATE_ABORTED: return "Aborted"; case WFT_STATE_UNDEFINED: return "Undefined"; default: break; } return "Unknown"; } void WORKFLOW_library_init(const struct WFGlobalSettings *settings) { WFGlobal::set_global_settings(settings); } workflow-0.11.8/src/manager/WFGlobal.h000066400000000000000000000122021476003635400174670ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) */ #ifndef _WFGLOBAL_H_ #define _WFGLOBAL_H_ #if __cplusplus < 201100 #error CPLUSPLUS VERSION required at least C++11. Please use "-std=c++11". 
#include
#endif
#include
#include
#include "CommScheduler.h"
#include "DnsCache.h"
#include "RouteManager.h"
#include "Executor.h"
#include "EndpointParams.h"
#include "WFResourcePool.h"
#include "WFNameService.h"
#include "WFDnsResolver.h"

/**
 * @file    WFGlobal.h
 * @brief   Workflow Global Settings & Workflow Global APIs
 */

/**
 * @brief   Workflow Library Global Settings
 * @details
 * To use settings that differ from the defaults, call
 * WORKFLOW_library_init() at the beginning of the process, before any of the
 * global managers (scheduler, executors, DNS client) has been created: each
 * of them reads these settings once, when it is first constructed.
 */
struct WFGlobalSettings
{
	struct EndpointParams endpoint_params;
	/// max_connections sizes the DNS resource pool; address_family is the
	/// ai_family hint used when validating resolv.conf nameservers
	struct EndpointParams dns_server_params;
	unsigned int dns_ttl_default;	///< in seconds, DNS TTL when the network request succeeds
	unsigned int dns_ttl_min;		///< in seconds, DNS TTL when the network request fails
	int dns_threads;				///< thread count passed to the DNS executor's init()
	int poller_threads;				///< poller thread count passed to the scheduler's init()
	int handler_threads;			///< handler thread count passed to the scheduler's init()
	int compute_threads;			///< auto-set by system CPU number if value<0
	int fio_max_events;				///< first max-events value tried for the file IO service; lowered on EAGAIN/EINVAL
	const char *resolv_conf_path;	///< resolv.conf to parse; NULL or "" means no WFDnsClient is created
	const char *hosts_path;			///< NOTE(review): presumably the hosts file path; its use is not visible here
};

/**
 * @brief   Default Workflow Library Global Settings
 */
static constexpr struct WFGlobalSettings GLOBAL_SETTINGS_DEFAULT =
{
	.endpoint_params	=	ENDPOINT_PARAMS_DEFAULT,
	.dns_server_params	=	ENDPOINT_PARAMS_DEFAULT,
	.dns_ttl_default	=	3600,
	.dns_ttl_min		=	60,
	.dns_threads		=	4,
	.poller_threads		=	4,
	.handler_threads	=	20,
	.compute_threads	=	-1,
	.fio_max_events		=	4096,
	.resolv_conf_path	=	"/etc/resolv.conf",
	.hosts_path			=	"/etc/hosts",
};

/**
 * @brief      Replace the Workflow Library Global Settings
 * @param[in]  settings          custom settings pointer; the pointed-to
 *                               struct is copied
 */
extern void WORKFLOW_library_init(const struct WFGlobalSettings *settings);

/**
 * @brief   Workflow Global Management Class
 * @details Workflow Global APIs
 */
class WFGlobal
{
public:
	/**
	 * @brief      register default port for one scheme string
	 * @param[in]  scheme           scheme string
	 * @param[in]  port             default port value
	 * @warning    No effect when the scheme already has a built-in default
	 *             port ("dns"/"dnss"/"http"/"https"/"redis"/"rediss"/
	 *             "mysql"/"mysqls"/"kafka"/"kafkas" in their predefined
	 *             capitalizations): get_default_port() consults the built-in
	 *             table before the registered one.
	 */
	static void register_scheme_port(const std::string& scheme, unsigned short port);

	/**
	 * 
@brief get default port string for one scheme string * @param[in] scheme scheme string * @return port string const pointer * @retval NULL fail, scheme not found * @retval not NULL success */ static const char *get_default_port(const std::string& scheme); /** * @brief get current global settings * @return current global settings const pointer * @note returnval never NULL */ static const struct WFGlobalSettings *get_global_settings() { return &settings_; } static void set_global_settings(const struct WFGlobalSettings *settings) { settings_ = *settings; } static const char *get_error_string(int state, int error); static bool increase_handler_thread() { return WFGlobal::get_scheduler()->increase_handler_thread() == 0; } static bool decrease_handler_thread() { return WFGlobal::get_scheduler()->decrease_handler_thread() == 0; } static bool increase_compute_thread() { return WFGlobal::get_compute_executor()->increase_thread() == 0; } static bool decrease_compute_thread() { return WFGlobal::get_compute_executor()->decrease_thread() == 0; } // Internal usage only public: static bool is_scheduler_created(); static class CommScheduler *get_scheduler(); static SSL_CTX *get_ssl_client_ctx(); static SSL_CTX *new_ssl_server_ctx(); static class ExecQueue *get_exec_queue(const std::string& queue_name); static class Executor *get_compute_executor(); static class IOService *get_io_service(); static class ExecQueue *get_dns_queue(); static class Executor *get_dns_executor(); static class WFDnsClient *get_dns_client(); static class WFResourcePool *get_dns_respool(); static class RouteManager *get_route_manager() { return &route_manager_; } static class DnsCache *get_dns_cache() { return &dns_cache_; } static class WFDnsResolver *get_dns_resolver() { return &dns_resolver_; } static class WFNameService *get_name_service() { return &name_service_; } public: static int sync_operation_begin(); static void sync_operation_end(int cookie); private: static struct WFGlobalSettings settings_; 
static RouteManager route_manager_; static DnsCache dns_cache_; static WFDnsResolver dns_resolver_; static WFNameService name_service_; }; #endif workflow-0.11.8/src/manager/xmake.lua000066400000000000000000000002311476003635400174700ustar00rootroot00000000000000target("manager") add_files("*.cc") set_kind("object") if not has_config("upstream") then remove_files("UpstreamManager.cc") end workflow-0.11.8/src/nameservice/000077500000000000000000000000001476003635400165535ustar00rootroot00000000000000workflow-0.11.8/src/nameservice/CMakeLists.txt000066400000000000000000000004011476003635400213060ustar00rootroot00000000000000cmake_minimum_required(VERSION 3.6) project(nameservice) set(SRC WFNameService.cc WFDnsResolver.cc ) if (NOT UPSTREAM STREQUAL "n") set(SRC ${SRC} WFServiceGovernance.cc UpstreamPolicies.cc ) endif () add_library(${PROJECT_NAME} OBJECT ${SRC}) workflow-0.11.8/src/nameservice/UpstreamPolicies.cc000066400000000000000000000433331476003635400223600ustar00rootroot00000000000000/* Copyright (c) 2021 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Authors: Li Yingxin (liyingxin@sogou-inc.com) Wang Zhulei (wangzhulei@sogou-inc.com) */ #include #include #include #include #include "rbtree.h" #include "URIParser.h" #include "UpstreamPolicies.h" class EndpointGroup { public: EndpointGroup(int group_id, UPSGroupPolicy *policy) : mutex(PTHREAD_MUTEX_INITIALIZER), gen(rand()) { this->id = group_id; this->policy = policy; this->nalives = 0; this->weight = 0; } EndpointAddress *get_one(WFNSTracing *tracing); EndpointAddress *get_one_backup(WFNSTracing *tracing); public: int id; UPSGroupPolicy *policy; struct rb_node rb; pthread_mutex_t mutex; std::mt19937 gen; std::vector mains; std::vector backups; std::atomic nalives; int weight; }; UPSAddrParams::UPSAddrParams(const struct AddressParams *params, const std::string& address) : PolicyAddrParams(params) { this->weight = params->weight; this->server_type = params->server_type; this->group_id = params->group_id; if (this->group_id < 0) this->group_id = -1; if (this->weight == 0) this->weight = 1; } void UPSGroupPolicy::get_main_address(std::vector& addr_list) { UPSAddrParams *params; pthread_rwlock_rdlock(&this->rwlock); for (const EndpointAddress *server : this->servers) { params = static_cast(server->params); if (params->server_type == 0) addr_list.push_back(server->address); } pthread_rwlock_unlock(&this->rwlock); } UPSGroupPolicy::UPSGroupPolicy() { this->group_map.rb_node = NULL; this->default_group = new EndpointGroup(-1, this); rb_link_node(&this->default_group->rb, NULL, &this->group_map.rb_node); rb_insert_color(&this->default_group->rb, &this->group_map); } UPSGroupPolicy::~UPSGroupPolicy() { EndpointGroup *group; while (this->group_map.rb_node) { group = rb_entry(this->group_map.rb_node, EndpointGroup, rb); rb_erase(this->group_map.rb_node, &this->group_map); delete group; } } inline bool UPSGroupPolicy::is_alive(const EndpointAddress *addr) const { UPSAddrParams *params = static_cast(addr->params); return ((params->group_id < 0 && addr->fail_count < 
addr->params->max_fails) || (params->group_id >= 0 && params->group->nalives > 0)); } void UPSGroupPolicy::recover_one_server(const EndpointAddress *addr) { this->nalives++; UPSAddrParams *params = static_cast(addr->params); params->group->nalives++; } void UPSGroupPolicy::fuse_one_server(const EndpointAddress *addr) { this->nalives--; UPSAddrParams *params = static_cast(addr->params); params->group->nalives--; } void UPSGroupPolicy::add_server(const std::string& address, const AddressParams *params) { EndpointAddress *addr = new EndpointAddress(address, new UPSAddrParams(params, address)); pthread_rwlock_wrlock(&this->rwlock); this->add_server_locked(addr); pthread_rwlock_unlock(&this->rwlock); } int UPSGroupPolicy::replace_server(const std::string& address, const AddressParams *params) { int ret; EndpointAddress *addr = new EndpointAddress(address, new UPSAddrParams(params, address)); pthread_rwlock_wrlock(&this->rwlock); ret = this->remove_server_locked(address); this->add_server_locked(addr); pthread_rwlock_unlock(&this->rwlock); return ret; } bool UPSGroupPolicy::select(const ParsedURI& uri, WFNSTracing *tracing, EndpointAddress **addr) { pthread_rwlock_rdlock(&this->rwlock); unsigned int n = (unsigned int)this->servers.size(); if (n == 0) { pthread_rwlock_unlock(&this->rwlock); return false; } this->check_breaker(); // select_addr == NULL will happen only in consistent_hash EndpointAddress *select_addr = this->first_strategy(uri, tracing); if (!select_addr || select_addr->fail_count >= select_addr->params->max_fails) { if (select_addr) select_addr = this->check_and_get(select_addr, true, tracing); if (!select_addr && this->try_another) select_addr = this->another_strategy(uri, tracing); } if (!select_addr) select_addr = this->default_group->get_one_backup(tracing); if (select_addr) { *addr = select_addr; ++select_addr->ref; } pthread_rwlock_unlock(&this->rwlock); return !!select_addr; } /* * addr_failed true: return an available one. 
If not exists, return NULL. * false: means addr maybe group-alive. * If addr is not available, get one from addr->group. */ EndpointAddress *UPSGroupPolicy::check_and_get(EndpointAddress *addr, bool addr_failed, WFNSTracing *tracing) { UPSAddrParams *params = static_cast(addr->params); if (addr_failed) // means fail_count >= max_fails { if (params->group_id == -1) return NULL; return params->group->get_one(tracing); } if (addr && addr->fail_count >= addr->params->max_fails && params->group_id >= 0) { EndpointAddress *tmp = params->group->get_one(tracing); if (tmp) addr = tmp; } return addr; } EndpointAddress *EndpointGroup::get_one(WFNSTracing *tracing) { if (this->nalives == 0) return NULL; EndpointAddress *server; EndpointAddress *addr = NULL; pthread_mutex_lock(&this->mutex); std::shuffle(this->mains.begin(), this->mains.end(), this->gen); for (size_t i = 0; i < this->mains.size(); i++) { server = this->mains[i]; if (server->fail_count < server->params->max_fails && WFServiceGovernance::in_select_history(tracing, server) == false) { addr = server; break; } } if (!addr) { std::shuffle(this->backups.begin(), this->backups.end(), this->gen); for (size_t i = 0; i < this->backups.size(); i++) { server = this->backups[i]; if (server->fail_count < server->params->max_fails && WFServiceGovernance::in_select_history(tracing, server) == false) { addr = server; break; } } } pthread_mutex_unlock(&this->mutex); return addr; } EndpointAddress *EndpointGroup::get_one_backup(WFNSTracing *tracing) { if (this->nalives == 0) return NULL; EndpointAddress *server; EndpointAddress *addr = NULL; pthread_mutex_lock(&this->mutex); std::shuffle(this->backups.begin(), this->backups.end(), this->gen); for (size_t i = 0; i < this->backups.size(); i++) { server = this->backups[i]; if (server->fail_count < server->params->max_fails && WFServiceGovernance::in_select_history(tracing, server) == false) { addr = server; break; } } pthread_mutex_unlock(&this->mutex); return addr; } void 
UPSGroupPolicy::add_server_locked(EndpointAddress *addr) { UPSAddrParams *params = static_cast(addr->params); int group_id = params->group_id; rb_node **p = &this->group_map.rb_node; rb_node *parent = NULL; EndpointGroup *group; this->server_map[addr->address].push_back(addr); if (params->server_type == 0) this->servers.push_back(addr); while (*p) { parent = *p; group = rb_entry(*p, EndpointGroup, rb); if (group_id < group->id) p = &(*p)->rb_left; else if (group_id > group->id) p = &(*p)->rb_right; else break; } if (*p == NULL) { group = new EndpointGroup(group_id, this); rb_link_node(&group->rb, parent, p); rb_insert_color(&group->rb, &this->group_map); } pthread_mutex_lock(&group->mutex); params->group = group; this->recover_one_server(addr); if (params->server_type == 0) { group->mains.push_back(addr); group->weight += params->weight; } else group->backups.push_back(addr); pthread_mutex_unlock(&group->mutex); } int UPSGroupPolicy::remove_server_locked(const std::string& address) { const auto map_it = this->server_map.find(address); size_t n = this->servers.size(); size_t new_n = 0; int ret = 0; for (size_t i = 0; i < n; i++) { if (this->servers[i]->address != address) this->servers[new_n++] = this->servers[i]; } this->servers.resize(new_n); if (map_it != this->server_map.cend()) { for (EndpointAddress *addr : map_it->second) { UPSAddrParams *params = static_cast(addr->params); EndpointGroup *group = params->group; std::vector *vec; if (params->server_type == 0) vec = &group->mains; else vec = &group->backups; pthread_mutex_lock(&group->mutex); if (params->server_type == 0) group->weight -= params->weight; for (auto it = vec->begin(); it != vec->end(); ++it) { if (*it == addr) { vec->erase(it); break; } } if (--addr->ref == 0) { this->pre_delete_server(addr); delete addr; } pthread_mutex_unlock(&group->mutex); ret++; } this->server_map.erase(map_it); } return ret; } EndpointAddress *UPSGroupPolicy::consistent_hash_with_group(unsigned int hash, WFNSTracing 
*tracing) { if (this->nalives == 0) return NULL; std::map::iterator it; it = this->addr_hash.lower_bound(hash); if (it == this->addr_hash.end()) it = this->addr_hash.begin(); while (!this->is_alive(it->second)) { it++; if (it == this->addr_hash.end()) it = this->addr_hash.begin(); } return this->check_and_get(it->second, false, tracing); } #define VIRTUAL_GROUP_SIZE 16 void UPSGroupPolicy::hash_map_add_addr(EndpointAddress *addr) { UPSAddrParams *params = static_cast(addr->params); if (params->server_type == 0) { static std::hash std_hash; unsigned int hash_value; size_t ip_count = this->server_map[addr->address].size(); for (int i = 0; i < VIRTUAL_GROUP_SIZE * params->weight; i++) { hash_value = std_hash(addr->address + "|v" + std::to_string(i) + "|n" + std::to_string(ip_count)); this->addr_hash.insert(std::make_pair(hash_value, addr)); } } } void UPSGroupPolicy::hash_map_remove_addr(const std::string& address) { std::map::iterator it; for (it = this->addr_hash.begin(); it != this->addr_hash.end();) { if (it->second->address == address) this->addr_hash.erase(it++); else it++; } } int UPSRoundRobinPolicy::remove_server_locked(const std::string& address) { if (servers.size() != 0) { size_t cur_idx = this->cur_idx % servers.size(); for (size_t i = 0; i < cur_idx; i++) { if (this->servers[i]->address == address) this->cur_idx--; } } return UPSGroupPolicy::remove_server_locked(address); } EndpointAddress *UPSRoundRobinPolicy::first_strategy(const ParsedURI& uri, WFNSTracing *tracing) { return this->servers[this->cur_idx++ % this->servers.size()]; } EndpointAddress *UPSRoundRobinPolicy::another_strategy(const ParsedURI& uri, WFNSTracing *tracing) { EndpointAddress *addr = this->servers[this->cur_idx++ % this->servers.size()]; return this->check_and_get(addr, false, tracing); } void UPSWeightedRandomPolicy::add_server_locked(EndpointAddress *addr) { UPSAddrParams *params = static_cast(addr->params); UPSGroupPolicy::add_server_locked(addr); if (params->server_type == 0) 
this->total_weight += params->weight; } int UPSWeightedRandomPolicy::remove_server_locked(const std::string& address) { UPSAddrParams *params; const auto map_it = this->server_map.find(address); if (map_it != this->server_map.cend()) { for (EndpointAddress *addr : map_it->second) { params = static_cast(addr->params); if (params->server_type == 0) this->total_weight -= params->weight; } } return UPSGroupPolicy::remove_server_locked(address); } int UPSWeightedRandomPolicy::select_history_weight(WFNSTracing *tracing) { struct TracingData *tracing_data = (struct TracingData *)tracing->data; if (!tracing_data) return 0; int ret = 0; for (EndpointAddress *server : tracing_data->history) ret += ((UPSAddrParams *)server->params)->weight; return ret; } EndpointAddress *UPSWeightedRandomPolicy::first_strategy(const ParsedURI& uri, WFNSTracing *tracing) { int x = 0; int s = 0; size_t idx; UPSAddrParams *params; int temp_weight = this->total_weight; temp_weight -= UPSWeightedRandomPolicy::select_history_weight(tracing); if (temp_weight > 0) x = rand() % temp_weight; for (idx = 0; idx < this->servers.size(); idx++) { if (WFServiceGovernance::in_select_history(tracing, this->servers[idx])) continue; params = static_cast(this->servers[idx]->params); s += params->weight; if (s > x) break; } if (idx == this->servers.size()) idx--; return this->servers[idx]; } EndpointAddress *UPSWeightedRandomPolicy::another_strategy(const ParsedURI& uri, WFNSTracing *tracing) { /* When all servers are down, recover all servers if any server * reaches fusing timeout. 
*/ if (this->available_weight == 0) this->try_clear_breaker(); int temp_weight = this->available_weight; if (temp_weight == 0) return NULL; UPSAddrParams *params; EndpointAddress *addr = NULL; int x = rand() % temp_weight; int s = 0; for (EndpointAddress *server : this->servers) { if (this->is_alive(server)) { addr = server; params = static_cast(server->params); s += params->weight; if (s > x) break; } } if (!addr) return NULL; return this->check_and_get(addr, false, tracing); } void UPSWeightedRandomPolicy::recover_one_server(const EndpointAddress *addr) { UPSAddrParams *params = static_cast(addr->params); this->nalives++; if (params->group->nalives++ == 0 && params->group->id > 0) this->available_weight += params->group->weight; if (params->group_id < 0 && params->server_type == 0) this->available_weight += params->weight; } void UPSWeightedRandomPolicy::fuse_one_server(const EndpointAddress *addr) { UPSAddrParams *params = static_cast(addr->params); this->nalives--; if (--params->group->nalives == 0 && params->group->id > 0) this->available_weight -= params->group->weight; if (params->group_id < 0 && params->server_type == 0) this->available_weight -= params->weight; } EndpointAddress *UPSVNSWRRPolicy::first_strategy(const ParsedURI& uri, WFNSTracing *tracing) { int idx = this->cur_idx.fetch_add(1); int pos = 0; for (int i = 0; i < this->total_weight; i++, idx++) { pos = this->pre_generated_vec[idx % this->pre_generated_vec.size()]; if (WFServiceGovernance::in_select_history(tracing, this->servers[pos])) continue; break; } return this->servers[pos]; } void UPSVNSWRRPolicy::init_virtual_nodes() { UPSAddrParams *params; size_t start_pos = this->pre_generated_vec.size(); size_t end_pos = this->total_weight; this->pre_generated_vec.resize(end_pos); for (size_t i = start_pos; i < end_pos; i++) { for (size_t j = 0; j < this->servers.size(); j++) { const EndpointAddress *server = this->servers[j]; params = static_cast(server->params); this->current_weight_vec[j] += 
params->weight; } std::vector::iterator biggest = std::max_element(this->current_weight_vec.begin(), this->current_weight_vec.end()); this->pre_generated_vec[i] = std::distance(this->current_weight_vec.begin(), biggest); this->current_weight_vec[this->pre_generated_vec[i]] -= this->total_weight; } } void UPSVNSWRRPolicy::init() { if (this->total_weight <= 0) return; this->pre_generated_vec.clear(); this->cur_idx = rand() % this->total_weight; std::vector t(this->servers.size(), 0); this->current_weight_vec.swap(t); this->init_virtual_nodes(); } void UPSVNSWRRPolicy::add_server_locked(EndpointAddress *addr) { UPSWeightedRandomPolicy::add_server_locked(addr); init(); } int UPSVNSWRRPolicy::remove_server_locked(const std::string& address) { int ret = UPSWeightedRandomPolicy::remove_server_locked(address); init(); return ret; } EndpointAddress *UPSConsistentHashPolicy::first_strategy(const ParsedURI& uri, WFNSTracing *tracing) { unsigned int hash_value = this->consistent_hash( uri.path ? uri.path : "", uri.query ? uri.query : "", uri.fragment ? uri.fragment : ""); return this->consistent_hash_with_group(hash_value, tracing); } void UPSConsistentHashPolicy::add_server_locked(EndpointAddress *addr) { UPSGroupPolicy::add_server_locked(addr); this->hash_map_add_addr(addr); } int UPSConsistentHashPolicy::remove_server_locked(const std::string& address) { this->hash_map_remove_addr(address); return UPSGroupPolicy::remove_server_locked(address); } EndpointAddress *UPSManualPolicy::first_strategy(const ParsedURI& uri, WFNSTracing *tracing) { unsigned int idx = this->manual_select(uri.path ? uri.path : "", uri.query ? uri.query : "", uri.fragment ? uri.fragment : ""); if (idx >= this->servers.size()) idx %= this->servers.size(); return this->servers[idx]; } EndpointAddress *UPSManualPolicy::another_strategy(const ParsedURI& uri, WFNSTracing *tracing) { unsigned int hash_value = this->another_select( uri.path ? uri.path : "", uri.query ? uri.query : "", uri.fragment ? 
uri.fragment : ""); return this->consistent_hash_with_group(hash_value, tracing); } void UPSManualPolicy::add_server_locked(EndpointAddress *addr) { UPSGroupPolicy::add_server_locked(addr); if (this->try_another) this->hash_map_add_addr(addr); } int UPSManualPolicy::remove_server_locked(const std::string& address) { if (this->try_another) this->hash_map_remove_addr(address); return UPSGroupPolicy::remove_server_locked(address); } workflow-0.11.8/src/nameservice/UpstreamPolicies.h000066400000000000000000000126701476003635400222220ustar00rootroot00000000000000/* Copyright (c) 2021 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Authors: Li Yingxin (liyingxin@sogou-inc.com)
         Wang Zhulei (wangzhulei@sogou-inc.com)
*/

#ifndef _UPSTREAMPOLICIES_H_
#define _UPSTREAMPOLICIES_H_

// NOTE(review): the header names of the next five directives are not present
// in the text under review - confirm against the original file.
#include
#include
#include
#include
#include
#include "URIParser.h"
#include "EndpointParams.h"
#include "WFNameService.h"
#include "WFServiceGovernance.h"

// Callback type used by the consistent-hash and manual policies; it is called
// with the URI path, query and fragment strings.
// NOTE(review): the std::function signature appears to be missing here.
using upstream_route_t = std::function;

class EndpointGroup;
class UPSGroupPolicy;

// Per-endpoint parameters used by the upstream policies.
class UPSAddrParams : public PolicyAddrParams
{
public:
	unsigned short weight;
	// 0 is handled as a "main" server by the policies; any other value goes
	// to the group's backup list.
	short server_type;
	int group_id;
	// Group the endpoint currently belongs to.
	EndpointGroup *group;

	UPSAddrParams(const struct AddressParams *params,
				  const std::string& address);
};

// Base class of the upstream policies: keeps endpoints organised in groups
// and offers a shared hash ring (addr_hash) to subclasses.
class UPSGroupPolicy : public WFServiceGovernance
{
public:
	UPSGroupPolicy();
	virtual ~UPSGroupPolicy();

public:
	virtual bool select(const ParsedURI& uri, WFNSTracing *tracing,
						EndpointAddress **addr);
	virtual void add_server(const std::string& address,
							const struct AddressParams *params);
	virtual int replace_server(const std::string& address,
							   const struct AddressParams *params);
	// NOTE(review): element type of the vector appears to be missing.
	void get_main_address(std::vector& addr_list);

protected:
	// Red-black tree of EndpointGroup nodes.
	struct rb_root group_map;
	EndpointGroup *default_group;

private:
	virtual void recover_one_server(const EndpointAddress *addr);
	virtual void fuse_one_server(const EndpointAddress *addr);

protected:
	virtual void add_server_locked(EndpointAddress *addr);
	virtual int remove_server_locked(const std::string& address);

	EndpointAddress *check_and_get(EndpointAddress *addr,
								   bool addr_failed,
								   WFNSTracing *tracing);

	bool is_alive(const EndpointAddress *addr) const;

protected:
	EndpointAddress *consistent_hash_with_group(unsigned int hash,
												WFNSTracing *tracing);
	void hash_map_add_addr(EndpointAddress *addr);
	void hash_map_remove_addr(const std::string& address);

	// Hash ring: hash value -> endpoint.
	// NOTE(review): key/value template arguments appear to be missing.
	std::map addr_hash;
};

// Round-robin selection over the main servers.
class UPSRoundRobinPolicy : public UPSGroupPolicy
{
public:
	UPSRoundRobinPolicy(bool try_another) :
		cur_idx(0)
	{
		this->try_another = try_another;
	}

protected:
	virtual EndpointAddress *first_strategy(const ParsedURI& uri,
											WFNSTracing *tracing);
	virtual EndpointAddress *another_strategy(const ParsedURI& uri,
											  WFNSTracing *tracing);

protected:
	virtual int remove_server_locked(const std::string& address);

protected:
	// Rotation cursor.
	// NOTE(review): the atomic's value type appears to be missing.
	std::atomic cur_idx;
};

// Random selection weighted by UPSAddrParams::weight.
class UPSWeightedRandomPolicy : public UPSGroupPolicy
{
public:
	UPSWeightedRandomPolicy(bool try_another)
	{
		this->total_weight = 0;
		this->available_weight = 0;
		this->try_another = try_another;
	}

protected:
	virtual EndpointAddress *first_strategy(const ParsedURI& uri,
											WFNSTracing *tracing);
	virtual EndpointAddress *another_strategy(const ParsedURI& uri,
											  WFNSTracing *tracing);

protected:
	virtual void add_server_locked(EndpointAddress *addr);
	virtual int remove_server_locked(const std::string& address);
	// Sum of the weights of all main servers.
	int total_weight;
	// Maintained by recover_one_server() / fuse_one_server().
	int available_weight;

private:
	virtual void recover_one_server(const EndpointAddress *addr);
	virtual void fuse_one_server(const EndpointAddress *addr);
	static int select_history_weight(WFNSTracing *tracing);
};

// Weighted round-robin that walks a pre-generated index sequence.
class UPSVNSWRRPolicy : public UPSWeightedRandomPolicy
{
public:
	UPSVNSWRRPolicy() :
		UPSWeightedRandomPolicy(false)
	{
		this->cur_idx = 0;
		this->try_another = false;
	};

protected:
	virtual EndpointAddress *first_strategy(const ParsedURI& uri,
											WFNSTracing *tracing);

private:
	virtual void add_server_locked(EndpointAddress *addr);
	virtual int remove_server_locked(const std::string& address);
	void init();
	void init_virtual_nodes();
	// NOTE(review): element types of both vectors and of the atomic appear
	// to be missing.
	std::vector pre_generated_vec;
	std::vector current_weight_vec;
	std::atomic cur_idx;
};

// Selection by a user-supplied hash of the URI, resolved on the hash ring.
class UPSConsistentHashPolicy : public UPSGroupPolicy
{
public:
	UPSConsistentHashPolicy(upstream_route_t consistent_hash) :
		consistent_hash(std::move(consistent_hash))
	{
	}

protected:
	virtual EndpointAddress *first_strategy(const ParsedURI& uri,
											WFNSTracing *tracing);

private:
	virtual void add_server_locked(EndpointAddress *addr);
	virtual int remove_server_locked(const std::string& address);

	upstream_route_t consistent_hash;
};

// Selection fully driven by user callbacks: `select` yields a server index,
// `try_another_select` yields a hash for the fallback pick.
class UPSManualPolicy : public UPSGroupPolicy
{
public:
	UPSManualPolicy(bool try_another, upstream_route_t select,
					upstream_route_t try_another_select) :
		manual_select(std::move(select)),
		another_select(std::move(try_another_select))
	{
		this->try_another = try_another;
	}

	EndpointAddress *first_strategy(const ParsedURI& uri,
									WFNSTracing *tracing);
	EndpointAddress *another_strategy(const ParsedURI& uri,
									  WFNSTracing *tracing);

private:
	virtual void add_server_locked(EndpointAddress *addr);
	virtual int remove_server_locked(const std::string& address);

	upstream_route_t manual_select;
	upstream_route_t another_select;
};

#endif

workflow-0.11.8/src/nameservice/WFDnsResolver.cc000066400000000000000000000414011476003635400215650ustar00rootroot00000000000000/* Copyright (c) 2020 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
Authors: Xie Han (xiehan@sogou-inc.com) Liu Kai (liukaidx@sogou-inc.com) Wu Jiaxu (wujiaxu@sogou-inc.com) */ #include #include #include #include #include #include #include #include #include #include #include #include #include "EndpointParams.h" #include "RouteManager.h" #include "WFGlobal.h" #include "WFTaskFactory.h" #include "WFResourcePool.h" #include "WFNameService.h" #include "DnsCache.h" #include "DnsUtil.h" #include "WFDnsClient.h" #include "WFDnsResolver.h" #define HOSTS_LINEBUF_INIT_SIZE 128 #define PORT_STR_MAX 5 class DnsInput { public: DnsInput() : port_(0), numeric_host_(false), family_(AF_UNSPEC) {} DnsInput(const std::string& host, unsigned short port, bool numeric_host, int family) : host_(host), port_(port), numeric_host_(numeric_host), family_(family) {} void reset(const std::string& host, unsigned short port) { host_.assign(host); port_ = port; numeric_host_ = false; family_ = AF_UNSPEC; } void reset(const std::string& host, unsigned short port, bool numeric_host, int family) { host_.assign(host); port_ = port; numeric_host_ = numeric_host; family_ = family; } const std::string& get_host() const { return host_; } unsigned short get_port() const { return port_; } bool is_numeric_host() const { return numeric_host_; } protected: std::string host_; unsigned short port_; bool numeric_host_; int family_; friend class DnsRoutine; }; class DnsOutput { public: DnsOutput(): error_(0), addrinfo_(NULL) {} ~DnsOutput() { if (addrinfo_) { if (addrinfo_->ai_flags) freeaddrinfo(addrinfo_); else free(addrinfo_); } } int get_error() const { return error_; } const struct addrinfo *get_addrinfo() const { return addrinfo_; } //if DONOT want DnsOutput release addrinfo, use move_addrinfo in callback struct addrinfo *move_addrinfo() { struct addrinfo *p = addrinfo_; addrinfo_ = NULL; return p; } protected: int error_; struct addrinfo *addrinfo_; friend class DnsRoutine; }; class DnsRoutine { public: static void run(const DnsInput *in, DnsOutput *out); static void 
create(DnsOutput *out, int error, struct addrinfo *ai) { if (out->addrinfo_) { if (out->addrinfo_->ai_flags) freeaddrinfo(out->addrinfo_); else free(out->addrinfo_); } out->error_ = error; out->addrinfo_ = ai; } private: static void run_local_path(const std::string& path, DnsOutput *out); }; void DnsRoutine::run_local_path(const std::string& path, DnsOutput *out) { struct sockaddr_un *sun = NULL; if (path.size() + 1 <= sizeof sun->sun_path) { size_t size = sizeof (struct addrinfo) + sizeof (struct sockaddr_un); out->addrinfo_ = (struct addrinfo *)calloc(size, 1); if (out->addrinfo_) { sun = (struct sockaddr_un *)(out->addrinfo_ + 1); sun->sun_family = AF_UNIX; memcpy(sun->sun_path, path.c_str(), path.size()); out->addrinfo_->ai_family = AF_UNIX; out->addrinfo_->ai_socktype = SOCK_STREAM; out->addrinfo_->ai_addr = (struct sockaddr *)sun; size = offsetof(struct sockaddr_un, sun_path) + path.size() + 1; out->addrinfo_->ai_addrlen = size; out->error_ = 0; return; } } else errno = EINVAL; out->error_ = EAI_SYSTEM; } void DnsRoutine::run(const DnsInput *in, DnsOutput *out) { if (in->host_[0] == '/') { run_local_path(in->host_, out); return; } struct addrinfo hints = { .ai_flags = AI_ADDRCONFIG | AI_NUMERICSERV, .ai_family = in->family_, .ai_socktype = SOCK_STREAM, }; char port_str[PORT_STR_MAX + 1]; if (in->is_numeric_host()) hints.ai_flags |= AI_NUMERICHOST; snprintf(port_str, PORT_STR_MAX + 1, "%u", in->port_); out->error_ = getaddrinfo(in->host_.c_str(), port_str, &hints, &out->addrinfo_); if (out->error_ == 0) out->addrinfo_->ai_flags = 1; } // Dns Thread task. For internal usage only. 
using ThreadDnsTask = WFThreadTask; using thread_dns_callback_t = std::function; struct DnsContext { unsigned short port; int eai_error; struct addrinfo *ai; }; static int __default_family() { struct addrinfo hints = { .ai_flags = AI_ADDRCONFIG, .ai_family = AF_UNSPEC, .ai_socktype = SOCK_STREAM, }; struct addrinfo *res; struct addrinfo *cur; int family = AF_UNSPEC; bool v4 = false; bool v6 = false; if (getaddrinfo(NULL, "1", &hints, &res) == 0) { for (cur = res; cur; cur = cur->ai_next) { if (cur->ai_family == AF_INET) v4 = true; else if (cur->ai_family == AF_INET6) v6 = true; } freeaddrinfo(res); if (v4 ^ v6) family = v4 ? AF_INET : AF_INET6; } return family; } // hosts line format: IP canonical_name [aliases...] [# Comment] static int __readaddrinfo_line(char *p, const char *name, const char *port, const struct addrinfo *hints, struct addrinfo **res) { const char *ip = NULL; char *start; start = p; while (*start != '\0' && *start != '#') start++; *start = '\0'; while (1) { while (isspace(*p)) p++; start = p; while (*p != '\0' && !isspace(*p)) p++; if (start == p) break; if (*p != '\0') *p++ = '\0'; if (ip == NULL) { ip = start; continue; } if (strcasecmp(name, start) == 0) { if (getaddrinfo(ip, port, hints, res) == 0) return 0; } } return 1; } static int __readaddrinfo(const char *path, const char *name, unsigned short port, const struct addrinfo *hints, struct addrinfo **res) { char port_str[PORT_STR_MAX + 1]; size_t bufsize = 0; char *line = NULL; int count = 0; int errno_bak; FILE *fp; int ret; fp = fopen(path, "r"); if (!fp) return EAI_SYSTEM; snprintf(port_str, PORT_STR_MAX + 1, "%u", port); errno_bak = errno; while ((ret = getline(&line, &bufsize, fp)) > 0) { if (__readaddrinfo_line(line, name, port_str, hints, res) == 0) { count++; res = &(*res)->ai_next; } } ret = ferror(fp) ? 
EAI_SYSTEM : EAI_NONAME; free(line); fclose(fp); if (count != 0) { errno = errno_bak; return 0; } return ret; } static ThreadDnsTask *__create_thread_dns_task(const std::string& host, unsigned short port, int family, thread_dns_callback_t callback) { auto *task = WFThreadTaskFactory:: create_thread_task(WFGlobal::get_dns_queue(), WFGlobal::get_dns_executor(), DnsRoutine::run, std::move(callback)); task->get_input()->reset(host, port, false, family); return task; } static std::string __get_cache_host(const std::string& hostname, int family) { char c; if (family == AF_UNSPEC) c = '*'; else if (family == AF_INET) c = '4'; else if (family == AF_INET6) c = '6'; else c = '?'; return hostname + c; } static std::string __get_guard_name(const std::string& cache_host, unsigned short port) { std::string guard_name("INTERNAL-dns:"); guard_name.append(cache_host).append(":"); guard_name.append(std::to_string(port)); return guard_name; } void WFResolverTask::dispatch() { if (this->msg_) { this->state = WFT_STATE_DNS_ERROR; this->error = (intptr_t)msg_; this->subtask_done(); return; } const ParsedURI& uri = ns_params_.uri; host_ = uri.host ? uri.host : ""; port_ = uri.port ? 
atoi(uri.port) : 0; DnsCache *dns_cache = WFGlobal::get_dns_cache(); const DnsCache::DnsHandle *addr_handle; std::string hostname = host_; int family = ep_params_.address_family; std::string cache_host = __get_cache_host(hostname, family); if (ns_params_.retry_times == 0) addr_handle = dns_cache->get_ttl(cache_host, port_); else addr_handle = dns_cache->get_confident(cache_host, port_); if (in_guard_ && (addr_handle == NULL || addr_handle->value.delayed())) { if (addr_handle) dns_cache->release(addr_handle); this->request_dns(); return; } if (addr_handle) { RouteManager *route_manager = WFGlobal::get_route_manager(); struct addrinfo *addrinfo = addr_handle->value.addrinfo; struct addrinfo first; if (ns_params_.fixed_addr && addrinfo->ai_next) { first = *addrinfo; first.ai_next = NULL; addrinfo = &first; } if (route_manager->get(ns_params_.type, addrinfo, ns_params_.info, &ep_params_, hostname, ns_params_.ssl_ctx, this->result) < 0) { this->state = WFT_STATE_SYS_ERROR; this->error = errno; } else this->state = WFT_STATE_SUCCESS; dns_cache->release(addr_handle); this->subtask_done(); return; } if (*host_) { char front = host_[0]; char back = host_[hostname.size() - 1]; struct in6_addr addr; int ret; if (strchr(host_, ':')) ret = inet_pton(AF_INET6, host_, &addr); else if (isdigit(back) && isdigit(front)) ret = inet_pton(AF_INET, host_, &addr); else if (front == '/') ret = 1; else ret = 0; if (ret == 1) { // 'true' means numeric host DnsInput dns_in(hostname, port_, true, AF_UNSPEC); DnsOutput dns_out; DnsRoutine::run(&dns_in, &dns_out); dns_callback_internal(&dns_out, (unsigned int)-1, (unsigned int)-1); this->subtask_done(); return; } } const char *hosts = WFGlobal::get_global_settings()->hosts_path; if (hosts) { struct addrinfo hints = { .ai_flags = AI_ADDRCONFIG | AI_NUMERICSERV | AI_NUMERICHOST, .ai_family = ep_params_.address_family, .ai_socktype = SOCK_STREAM, }; struct addrinfo *ai; int ret; ret = __readaddrinfo(hosts, host_, port_, &hints, &ai); if (ret == 0) 
{ DnsOutput out; DnsRoutine::create(&out, ret, ai); dns_callback_internal(&out, dns_ttl_default_, dns_ttl_min_); this->subtask_done(); return; } } std::string guard_name = __get_guard_name(cache_host, port_); WFConditional *guard = WFTaskFactory::create_guard(guard_name, this, &msg_); in_guard_ = true; has_next_ = true; series_of(this)->push_front(guard); this->subtask_done(); } void WFResolverTask::request_dns() { WFDnsClient *client = WFGlobal::get_dns_client(); if (client) { static int default_family = __default_family(); WFResourcePool *respool = WFGlobal::get_dns_respool(); int family = ep_params_.address_family; if (family == AF_UNSPEC) family = default_family; if (family == AF_INET || family == AF_INET6) { auto&& cb = std::bind(&WFResolverTask::dns_single_callback, this, std::placeholders::_1); WFDnsTask *dns_task = client->create_dns_task(host_, std::move(cb)); if (family == AF_INET6) dns_task->get_req()->set_question_type(DNS_TYPE_AAAA); WFConditional *cond = respool->get(dns_task); series_of(this)->push_front(cond); } else { struct DnsContext *dctx = new struct DnsContext[2]; WFDnsTask *task_v4; WFDnsTask *task_v6; ParallelWork *pwork; dctx[0].ai = NULL; dctx[1].ai = NULL; dctx[0].port = port_; dctx[1].port = port_; task_v4 = client->create_dns_task(host_, dns_partial_callback); task_v4->user_data = dctx; task_v6 = client->create_dns_task(host_, dns_partial_callback); task_v6->get_req()->set_question_type(DNS_TYPE_AAAA); task_v6->user_data = dctx + 1; auto&& cb = std::bind(&WFResolverTask::dns_parallel_callback, this, std::placeholders::_1); pwork = Workflow::create_parallel_work(std::move(cb)); pwork->set_context(dctx); WFConditional *cond_v4 = respool->get(task_v4); WFConditional *cond_v6 = respool->get(task_v6); pwork->add_series(Workflow::create_series_work(cond_v4, nullptr)); pwork->add_series(Workflow::create_series_work(cond_v6, nullptr)); series_of(this)->push_front(pwork); } } else { ThreadDnsTask *dns_task; auto&& cb = 
std::bind(&WFResolverTask::thread_dns_callback, this, std::placeholders::_1); dns_task = __create_thread_dns_task(host_, port_, ep_params_.address_family, std::move(cb)); series_of(this)->push_front(dns_task); } has_next_ = true; this->subtask_done(); } SubTask *WFResolverTask::done() { SeriesWork *series = series_of(this); if (!has_next_) task_callback(); else has_next_ = false; return series->pop(); } void WFResolverTask::dns_callback_internal(void *thrd_dns_output, unsigned int ttl_default, unsigned int ttl_min) { DnsOutput *dns_out = (DnsOutput *)thrd_dns_output; int dns_error = dns_out->get_error(); if (dns_error) { if (dns_error == EAI_SYSTEM) { this->state = WFT_STATE_SYS_ERROR; this->error = errno; } else { this->state = WFT_STATE_DNS_ERROR; this->error = dns_error; } } else { RouteManager *route_manager = WFGlobal::get_route_manager(); DnsCache *dns_cache = WFGlobal::get_dns_cache(); struct addrinfo *addrinfo = dns_out->move_addrinfo(); const DnsCache::DnsHandle *addr_handle; std::string hostname = host_; int family = ep_params_.address_family; std::string cache_host = __get_cache_host(hostname, family); addr_handle = dns_cache->put(cache_host, port_, addrinfo, (unsigned int)ttl_default, (unsigned int)ttl_min); if (route_manager->get(ns_params_.type, addrinfo, ns_params_.info, &ep_params_, hostname, ns_params_.ssl_ctx, this->result) < 0) { this->state = WFT_STATE_SYS_ERROR; this->error = errno; } else this->state = WFT_STATE_SUCCESS; dns_cache->release(addr_handle); } } void WFResolverTask::dns_single_callback(void *net_dns_task) { WFDnsTask *dns_task = (WFDnsTask *)net_dns_task; WFGlobal::get_dns_respool()->post(NULL); if (dns_task->get_state() == WFT_STATE_SUCCESS) { struct addrinfo *ai = NULL; int ret; ret = protocol::DnsUtil::getaddrinfo(dns_task->get_resp(), port_, &ai); DnsOutput out; DnsRoutine::create(&out, ret, ai); dns_callback_internal(&out, dns_ttl_default_, dns_ttl_min_); } else { this->state = WFT_STATE_DNS_ERROR; this->error = EAI_AGAIN; } 
task_callback(); } void WFResolverTask::dns_partial_callback(void *net_dns_task) { WFDnsTask *dns_task = (WFDnsTask *)net_dns_task; WFGlobal::get_dns_respool()->post(NULL); struct DnsContext *ctx = (struct DnsContext *)dns_task->user_data; ctx->ai = NULL; if (dns_task->get_state() == WFT_STATE_SUCCESS) { protocol::DnsResponse *resp = dns_task->get_resp(); ctx->eai_error = protocol::DnsUtil::getaddrinfo(resp, ctx->port, &ctx->ai); } else ctx->eai_error = EAI_AGAIN; } void WFResolverTask::dns_parallel_callback(const void *parallel) { const ParallelWork *pwork = (const ParallelWork *)parallel; struct DnsContext *c4 = (struct DnsContext *)pwork->get_context(); struct DnsContext *c6 = c4 + 1; if (c4->eai_error == 0 || c6->eai_error == 0) { struct addrinfo *ai = NULL; struct addrinfo **pai = &ai; DnsOutput out; *pai = c4->ai; while (*pai) pai = &(*pai)->ai_next; *pai = c6->ai; DnsRoutine::create(&out, 0, ai); dns_callback_internal(&out, dns_ttl_default_, dns_ttl_min_); } else { int eai_error = c4->eai_error; if (c6->eai_error == EAI_AGAIN) eai_error = EAI_AGAIN; this->state = WFT_STATE_DNS_ERROR; this->error = eai_error; } delete []c4; task_callback(); } void WFResolverTask::thread_dns_callback(void *thrd_dns_task) { ThreadDnsTask *dns_task = (ThreadDnsTask *)thrd_dns_task; if (dns_task->get_state() == WFT_STATE_SUCCESS) { DnsOutput *out = dns_task->get_output(); dns_callback_internal(out, dns_ttl_default_, dns_ttl_min_); } else { this->state = dns_task->get_state(); this->error = dns_task->get_error(); } task_callback(); } void WFResolverTask::task_callback() { if (in_guard_) { int family = ep_params_.address_family; std::string cache_host = __get_cache_host(host_, family); std::string guard_name = __get_guard_name(cache_host, port_); if (this->state == WFT_STATE_DNS_ERROR) msg_ = (void *)(intptr_t)this->error; WFTaskFactory::release_guard_safe(guard_name, msg_); } if (this->callback) this->callback(this); delete this; } WFRouterTask 
*WFDnsResolver::create_router_task(const struct WFNSParams *params, router_callback_t callback) { const struct WFGlobalSettings *settings = WFGlobal::get_global_settings(); unsigned int dns_ttl_default = settings->dns_ttl_default; unsigned int dns_ttl_min = settings->dns_ttl_min; const struct EndpointParams *ep_params = &settings->endpoint_params; return new WFResolverTask(params, dns_ttl_default, dns_ttl_min, ep_params, std::move(callback)); } workflow-0.11.8/src/nameservice/WFDnsResolver.h000066400000000000000000000046511476003635400214350ustar00rootroot00000000000000/* Copyright (c) 2020 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Author: Xie Han (xiehan@sogou-inc.com)
*/

#ifndef _WFDNSRESOLVER_H_
#define _WFDNSRESOLVER_H_

// NOTE(review): the header names of the next two directives are not present
// in the text under review - confirm against the original file.
#include
#include
#include "EndpointParams.h"
#include "WFNameService.h"

// Router task that resolves a host name into a route.
class WFResolverTask : public WFRouterTask
{
public:
	// Full constructor: copies the name-service and endpoint parameters and
	// records the DNS TTL settings.
	WFResolverTask(const struct WFNSParams *ns_params,
				   unsigned int dns_ttl_default, unsigned int dns_ttl_min,
				   const struct EndpointParams *ep_params,
				   router_callback_t&& cb) :
		WFRouterTask(std::move(cb)),
		ns_params_(*ns_params),
		ep_params_(*ep_params)
	{
		// A fixed connection limits the endpoint to one connection.
		if (ns_params_.fixed_conn)
			ep_params_.max_connections = 1;

		dns_ttl_default_ = dns_ttl_default;
		dns_ttl_min_ = dns_ttl_min;
		has_next_ = false;
		in_guard_ = false;
		msg_ = NULL;
	}

	// Reduced constructor.
	// NOTE(review): ep_params_, dns_ttl_default_ and dns_ttl_min_ are not
	// assigned here (only ep_params_.max_connections, conditionally);
	// presumably a derived class fills them in before dispatch - confirm.
	WFResolverTask(const struct WFNSParams *ns_params,
				   router_callback_t&& cb) :
		WFRouterTask(std::move(cb)),
		ns_params_(*ns_params)
	{
		if (ns_params_.fixed_conn)
			ep_params_.max_connections = 1;

		has_next_ = false;
		in_guard_ = false;
		msg_ = NULL;
	}

protected:
	virtual void dispatch();
	virtual SubTask *done();
	void set_has_next() { has_next_ = true; }

private:
	void thread_dns_callback(void *thrd_dns_task);
	void dns_single_callback(void *net_dns_task);
	static void dns_partial_callback(void *net_dns_task);
	void dns_parallel_callback(const void *parallel);
	void dns_callback_internal(void *thrd_dns_output,
							   unsigned int ttl_default,
							   unsigned int ttl_min);

	void request_dns();
	void task_callback();

protected:
	struct WFNSParams ns_params_;
	unsigned int dns_ttl_default_;
	unsigned int dns_ttl_min_;
	struct EndpointParams ep_params_;

private:
	// host_ and port_ are not set by the constructors.
	const char *host_;
	unsigned short port_;
	// Set when a follow-up task has been queued.
	bool has_next_;
	bool in_guard_;
	void *msg_;
};

// Name-service policy that creates WFResolverTask instances.
class WFDnsResolver : public WFNSPolicy
{
public:
	virtual WFRouterTask *create_router_task(const struct WFNSParams *params,
											 router_callback_t callback);
};

#endif

workflow-0.11.8/src/nameservice/WFNameService.cc000066400000000000000000000057411476003635400215270ustar00rootroot00000000000000/* Copyright (c) 2020 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Xie Han (xiehan@sogou-inc.com) */ #include #include #include #include #include #include "rbtree.h" #include "WFNameService.h" struct WFNSPolicyEntry { struct rb_node rb; WFNSPolicy *policy; char name[1]; }; int WFNameService::add_policy(const char *name, WFNSPolicy *policy) { struct rb_node **p = &this->root.rb_node; struct rb_node *parent = NULL; struct WFNSPolicyEntry *entry; int n, ret = -1; pthread_rwlock_wrlock(&this->rwlock); while (*p) { parent = *p; entry = rb_entry(*p, struct WFNSPolicyEntry, rb); n = strcasecmp(name, entry->name); if (n < 0) p = &(*p)->rb_left; else if (n > 0) p = &(*p)->rb_right; else break; } if (!*p) { size_t len = strlen(name); size_t size = offsetof(struct WFNSPolicyEntry, name) + len + 1; entry = (struct WFNSPolicyEntry *)malloc(size); if (entry) { memcpy(entry->name, name, len + 1); entry->policy = policy; rb_link_node(&entry->rb, parent, p); rb_insert_color(&entry->rb, &this->root); ret = 0; } } else errno = EEXIST; pthread_rwlock_unlock(&this->rwlock); return ret; } inline struct WFNSPolicyEntry *WFNameService::get_policy_entry(const char *name) { struct rb_node *p = this->root.rb_node; struct WFNSPolicyEntry *entry; int n; while (p) { entry = rb_entry(p, struct WFNSPolicyEntry, rb); n = strcasecmp(name, entry->name); if (n < 0) p = p->rb_left; else if (n > 0) p = p->rb_right; else return entry; } return NULL; } WFNSPolicy *WFNameService::get_policy(const char *name) { WFNSPolicy *policy = this->default_policy; 
struct WFNSPolicyEntry *entry; if (this->root.rb_node) { pthread_rwlock_rdlock(&this->rwlock); entry = this->get_policy_entry(name); if (entry) policy = entry->policy; pthread_rwlock_unlock(&this->rwlock); } return policy; } WFNSPolicy *WFNameService::del_policy(const char *name) { WFNSPolicy *policy = NULL; struct WFNSPolicyEntry *entry; pthread_rwlock_wrlock(&this->rwlock); entry = this->get_policy_entry(name); if (entry) { policy = entry->policy; rb_erase(&entry->rb, &this->root); } pthread_rwlock_unlock(&this->rwlock); free(entry); return policy; } WFNameService::~WFNameService() { struct WFNSPolicyEntry *entry; while (this->root.rb_node) { entry = rb_entry(this->root.rb_node, struct WFNSPolicyEntry, rb); rb_erase(&entry->rb, &this->root); free(entry); } } workflow-0.11.8/src/nameservice/WFNameService.h000066400000000000000000000061351476003635400213670ustar00rootroot00000000000000/* Copyright (c) 2020 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Author: Xie Han (xiehan@sogou-inc.com) */ #ifndef _WFNAMESERVICE_H_ #define _WFNAMESERVICE_H_ #include #include #include #include "rbtree.h" #include "Communicator.h" #include "Workflow.h" #include "WFTask.h" #include "RouteManager.h" #include "URIParser.h" #include "EndpointParams.h" class WFRouterTask : public WFGenericTask { public: RouteManager::RouteResult *get_result() { return &this->result; } public: void set_state(int state) { this->state = state; } void set_error(int error) { this->error = error; } protected: RouteManager::RouteResult result; std::function callback; protected: virtual SubTask *done() { SeriesWork *series = series_of(this); if (this->callback) this->callback(this); delete this; return series->pop(); } public: WFRouterTask(std::function&& cb) : callback(std::move(cb)) { } }; class WFNSTracing { public: void *data; void (*deleter)(void *); public: WFNSTracing() { this->data = NULL; this->deleter = NULL; } }; struct WFNSParams { enum TransportType type; ParsedURI& uri; const char *info; SSL_CTX *ssl_ctx; bool fixed_addr; bool fixed_conn; int retry_times; WFNSTracing *tracing; }; using router_callback_t = std::function; class WFNSPolicy { public: virtual WFRouterTask *create_router_task(const struct WFNSParams *params, router_callback_t callback) = 0; virtual void success(RouteManager::RouteResult *result, WFNSTracing *tracing, CommTarget *target) { RouteManager::notify_available(result->cookie, target); } virtual void failed(RouteManager::RouteResult *result, WFNSTracing *tracing, CommTarget *target) { if (target) RouteManager::notify_unavailable(result->cookie, target); } public: virtual ~WFNSPolicy() { } }; class WFNameService { public: int add_policy(const char *name, WFNSPolicy *policy); WFNSPolicy *get_policy(const char *name); WFNSPolicy *del_policy(const char *name); public: WFNSPolicy *get_default_policy() const { return this->default_policy; } void set_default_policy(WFNSPolicy *policy) { this->default_policy = policy; } private: 
WFNSPolicy *default_policy; struct rb_root root; pthread_rwlock_t rwlock; private: struct WFNSPolicyEntry *get_policy_entry(const char *name); public: WFNameService(WFNSPolicy *default_policy) : rwlock(PTHREAD_RWLOCK_INITIALIZER) { this->root.rb_node = NULL; this->default_policy = default_policy; } virtual ~WFNameService(); }; #endif workflow-0.11.8/src/nameservice/WFServiceGovernance.cc000066400000000000000000000300261476003635400227300ustar00rootroot00000000000000/* Copyright (c) 2021 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Authors: Li Yingxin (liyingxin@sogou-inc.com) Xie Han (xiehan@sogou-inc.com) */ #include #include #include #include #include #include #include "URIParser.h" #include "WFTaskError.h" #include "StringUtil.h" #include "WFGlobal.h" #include "WFNameService.h" #include "WFDnsResolver.h" #include "WFServiceGovernance.h" #define GET_CURRENT_SECOND std::chrono::duration_cast(std::chrono::steady_clock::now().time_since_epoch()).count() #define DNS_CACHE_LEVEL_1 1 #define DNS_CACHE_LEVEL_2 2 #define MTTR_SECONDS_DEFAULT 30 WFServiceGovernance::WFServiceGovernance() : breaker_lock(PTHREAD_MUTEX_INITIALIZER), rwlock(PTHREAD_RWLOCK_INITIALIZER) { this->nalives = 0; this->try_another = false; this->mttr_seconds = MTTR_SECONDS_DEFAULT; INIT_LIST_HEAD(&this->breaker_list); } WFServiceGovernance::~WFServiceGovernance() { for (EndpointAddress *addr : this->servers) delete addr; } PolicyAddrParams::PolicyAddrParams() { const struct AddressParams *params = &ADDRESS_PARAMS_DEFAULT; this->endpoint_params = params->endpoint_params; this->dns_ttl_default = params->dns_ttl_default; this->dns_ttl_min = params->dns_ttl_min; this->max_fails = params->max_fails; } PolicyAddrParams::PolicyAddrParams(const struct AddressParams *params) : endpoint_params(params->endpoint_params) { this->dns_ttl_default = params->dns_ttl_default; this->dns_ttl_min = params->dns_ttl_min; this->max_fails = params->max_fails; } EndpointAddress::EndpointAddress(const std::string& address, PolicyAddrParams *address_params) { std::vector arr = StringUtil::split(address, ':'); this->params = address_params; if (this->params->max_fails == 0) this->params->max_fails = 1; this->address = address; this->fail_count = 0; this->ref = 1; this->entry.list.next = NULL; this->entry.ptr = this; if (arr.size() == 0) this->host = ""; else this->host = arr[0]; if (arr.size() <= 1) this->port = ""; else this->port = arr[1]; } class WFSGResolverTask : public WFResolverTask { public: WFSGResolverTask(const struct WFNSParams *params, 
WFServiceGovernance *sg, router_callback_t&& cb) : WFResolverTask(params, std::move(cb)) { sg_ = sg; } protected: virtual void dispatch(); protected: WFServiceGovernance *sg_; }; static void copy_host_port(ParsedURI& uri, const EndpointAddress *addr) { if (!addr->host.empty()) { free(uri.host); uri.host = strdup(addr->host.c_str()); } if (!addr->port.empty()) { free(uri.port); uri.port = strdup(addr->port.c_str()); } } void WFSGResolverTask::dispatch() { WFNSTracing *tracing = ns_params_.tracing; EndpointAddress *addr; if (!sg_) { this->WFResolverTask::dispatch(); return; } if (sg_->pre_select_) { WFConditional *cond = sg_->pre_select_(this); if (cond) { series_of(this)->push_front(cond); this->set_has_next(); this->subtask_done(); return; } else if (this->state != WFT_STATE_UNDEFINED) { this->subtask_done(); return; } } if (sg_->select(ns_params_.uri, tracing, &addr)) { auto *tracing_data = (WFServiceGovernance::TracingData *)tracing->data; if (!tracing_data) { tracing_data = new WFServiceGovernance::TracingData; tracing_data->sg = sg_; tracing->data = tracing_data; tracing->deleter = WFServiceGovernance::tracing_deleter; } tracing_data->history.push_back(addr); sg_ = NULL; copy_host_port(ns_params_.uri, addr); dns_ttl_default_ = addr->params->dns_ttl_default; dns_ttl_min_ = addr->params->dns_ttl_min; ep_params_ = addr->params->endpoint_params; this->WFResolverTask::dispatch(); } else { this->state = WFT_STATE_TASK_ERROR; this->error = WFT_ERR_UPSTREAM_UNAVAILABLE; this->subtask_done(); } } WFRouterTask *WFServiceGovernance::create_router_task(const struct WFNSParams *params, router_callback_t callback) { return new WFSGResolverTask(params, this, std::move(callback)); } void WFServiceGovernance::tracing_deleter(void *data) { struct TracingData *tracing_data = (struct TracingData *)data; for (EndpointAddress *addr : tracing_data->history) { if (--addr->ref == 0) { pthread_rwlock_wrlock(&tracing_data->sg->rwlock); tracing_data->sg->pre_delete_server(addr); 
pthread_rwlock_unlock(&tracing_data->sg->rwlock); delete addr; } } delete tracing_data; } bool WFServiceGovernance::in_select_history(WFNSTracing *tracing, EndpointAddress *addr) { struct TracingData *tracing_data = (struct TracingData *)tracing->data; if (!tracing_data) return false; for (EndpointAddress *server : tracing_data->history) { if (server == addr) return true; } return false; } void WFServiceGovernance::recover_server_from_breaker(EndpointAddress *addr) { addr->fail_count = 0; pthread_mutex_lock(&this->breaker_lock); if (addr->entry.list.next) { list_del(&addr->entry.list); addr->entry.list.next = NULL; this->recover_one_server(addr); } pthread_mutex_unlock(&this->breaker_lock); } void WFServiceGovernance::fuse_server_to_breaker(EndpointAddress *addr) { pthread_mutex_lock(&this->breaker_lock); if (!addr->entry.list.next) { addr->broken_timeout = GET_CURRENT_SECOND + this->mttr_seconds; list_add_tail(&addr->entry.list, &this->breaker_list); this->fuse_one_server(addr); } pthread_mutex_unlock(&this->breaker_lock); } void WFServiceGovernance::pre_delete_server(EndpointAddress *addr) { pthread_mutex_lock(&this->breaker_lock); if (addr->entry.list.next) { list_del(&addr->entry.list); addr->entry.list.next = NULL; } else this->fuse_one_server(addr); pthread_mutex_unlock(&this->breaker_lock); } void WFServiceGovernance::success(RouteManager::RouteResult *result, WFNSTracing *tracing, CommTarget *target) { struct TracingData *tracing_data = (struct TracingData *)tracing->data; auto *v = &tracing_data->history; EndpointAddress *server = (*v)[v->size() - 1]; pthread_rwlock_wrlock(&this->rwlock); this->recover_server_from_breaker(server); pthread_rwlock_unlock(&this->rwlock); this->WFNSPolicy::success(result, tracing, target); } void WFServiceGovernance::failed(RouteManager::RouteResult *result, WFNSTracing *tracing, CommTarget *target) { struct TracingData *tracing_data = (struct TracingData *)tracing->data; auto *v = &tracing_data->history; EndpointAddress 
*server = (*v)[v->size() - 1]; pthread_rwlock_wrlock(&this->rwlock); if (++server->fail_count == server->params->max_fails) this->fuse_server_to_breaker(server); pthread_rwlock_unlock(&this->rwlock); this->WFNSPolicy::failed(result, tracing, target); } void WFServiceGovernance::check_breaker_locked(int64_t cur_time) { struct list_head *pos, *tmp; struct EndpointAddress::address_entry *entry; EndpointAddress *addr; list_for_each_safe(pos, tmp, &this->breaker_list) { entry = list_entry(pos, struct EndpointAddress::address_entry, list); addr = entry->ptr; if (cur_time >= addr->broken_timeout) { addr->fail_count = addr->params->max_fails - 1; this->recover_one_server(addr); list_del(pos); pos->next = NULL; } else break; } } void WFServiceGovernance::check_breaker() { pthread_mutex_lock(&this->breaker_lock); if (!list_empty(&this->breaker_list)) this->check_breaker_locked(GET_CURRENT_SECOND); pthread_mutex_unlock(&this->breaker_lock); } void WFServiceGovernance::try_clear_breaker() { pthread_mutex_lock(&this->breaker_lock); if (!list_empty(&this->breaker_list)) { struct list_head *pos = this->breaker_list.next; struct EndpointAddress::address_entry *entry; entry = list_entry(pos, struct EndpointAddress::address_entry, list); if (GET_CURRENT_SECOND >= entry->ptr->broken_timeout) this->check_breaker_locked(INT64_MAX); } pthread_mutex_unlock(&this->breaker_lock); } EndpointAddress *WFServiceGovernance::first_strategy(const ParsedURI& uri, WFNSTracing *tracing) { unsigned int idx = rand() % this->servers.size(); return this->servers[idx]; } EndpointAddress *WFServiceGovernance::another_strategy(const ParsedURI& uri, WFNSTracing *tracing) { return this->first_strategy(uri, tracing); } bool WFServiceGovernance::select(const ParsedURI& uri, WFNSTracing *tracing, EndpointAddress **addr) { pthread_rwlock_rdlock(&this->rwlock); unsigned int n = (unsigned int)this->servers.size(); if (n == 0) { pthread_rwlock_unlock(&this->rwlock); return false; } this->check_breaker(); if 
(this->nalives == 0) { pthread_rwlock_unlock(&this->rwlock); return false; } EndpointAddress *select_addr = this->first_strategy(uri, tracing); if (!select_addr || select_addr->fail_count >= select_addr->params->max_fails) { if (this->try_another) select_addr = this->another_strategy(uri, tracing); } if (select_addr) { *addr = select_addr; ++select_addr->ref; } pthread_rwlock_unlock(&this->rwlock); return !!select_addr; } void WFServiceGovernance::add_server_locked(EndpointAddress *addr) { this->server_map[addr->address].push_back(addr); this->servers.push_back(addr); this->recover_one_server(addr); } int WFServiceGovernance::remove_server_locked(const std::string& address) { const auto map_it = this->server_map.find(address); size_t n = this->servers.size(); size_t new_n = 0; int ret = 0; for (size_t i = 0; i < n; i++) { if (this->servers[i]->address != address) this->servers[new_n++] = this->servers[i]; } this->servers.resize(new_n); if (map_it != this->server_map.cend()) { for (EndpointAddress *addr : map_it->second) { if (--addr->ref == 0) { this->pre_delete_server(addr); delete addr; } ret++; } this->server_map.erase(map_it); } return ret; } void WFServiceGovernance::add_server(const std::string& address, const AddressParams *params) { EndpointAddress *addr = new EndpointAddress(address, new PolicyAddrParams(params)); pthread_rwlock_wrlock(&this->rwlock); this->add_server_locked(addr); pthread_rwlock_unlock(&this->rwlock); } int WFServiceGovernance::remove_server(const std::string& address) { int ret; pthread_rwlock_wrlock(&this->rwlock); ret = this->remove_server_locked(address); pthread_rwlock_unlock(&this->rwlock); return ret; } int WFServiceGovernance::replace_server(const std::string& address, const AddressParams *params) { int ret; EndpointAddress *addr = new EndpointAddress(address, new PolicyAddrParams(params)); pthread_rwlock_wrlock(&this->rwlock); ret = this->remove_server_locked(address); this->add_server_locked(addr); 
pthread_rwlock_unlock(&this->rwlock); return ret; } void WFServiceGovernance::enable_server(const std::string& address) { pthread_rwlock_wrlock(&this->rwlock); const auto map_it = this->server_map.find(address); if (map_it != this->server_map.cend()) { for (EndpointAddress *addr : map_it->second) this->recover_server_from_breaker(addr); } pthread_rwlock_unlock(&this->rwlock); } void WFServiceGovernance::disable_server(const std::string& address) { pthread_rwlock_wrlock(&this->rwlock); const auto map_it = this->server_map.find(address); if (map_it != this->server_map.cend()) { for (EndpointAddress *addr : map_it->second) { addr->fail_count = addr->params->max_fails; this->fuse_server_to_breaker(addr); } } pthread_rwlock_unlock(&this->rwlock); } void WFServiceGovernance::get_current_address(std::vector& addr_list) { pthread_rwlock_rdlock(&this->rwlock); for (const EndpointAddress *server : this->servers) addr_list.push_back(server->address); pthread_rwlock_unlock(&this->rwlock); } workflow-0.11.8/src/nameservice/WFServiceGovernance.h000066400000000000000000000127271476003635400226020ustar00rootroot00000000000000/* Copyright (c) 2021 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Authors: Li Yingxin (liyingxin@sogou-inc.com) Xie Han (xiehan@sogou-inc.com) */ #ifndef _WFSERVICEGOVERNANCE_H_ #define _WFSERVICEGOVERNANCE_H_ #include #include #include #include #include #include #include "URIParser.h" #include "EndpointParams.h" #include "WFNameService.h" struct AddressParams { struct EndpointParams endpoint_params; ///< Connection config unsigned int dns_ttl_default; ///< in seconds, DNS TTL when network request success unsigned int dns_ttl_min; ///< in seconds, DNS TTL when network request fail /** * - The max_fails directive sets the number of consecutive unsuccessful attempts to communicate with the server. * - After 30s following the server failure, upstream probe the server with some live client`s requests. * - If the probes have been successful, the server is marked as a live one. * - If max_fails is set to 1, it means server would out of upstream selection in 30 seconds when failed only once */ unsigned int max_fails; ///< [1, INT32_MAX] max_fails = 0 means max_fails = 1 unsigned short weight; ///< [1, 65535] weight = 0 means weight = 1. only for main server int server_type; ///< 0 for main and 1 for backup int group_id; ///< -1 means no group. Backup without group will be backup for any main }; static constexpr struct AddressParams ADDRESS_PARAMS_DEFAULT = { .endpoint_params = ENDPOINT_PARAMS_DEFAULT, .dns_ttl_default = 12 * 3600, .dns_ttl_min = 180, .max_fails = 200, .weight = 1, .server_type = 0, /* 0 for main and 1 for backup. 
*/ .group_id = -1, }; class PolicyAddrParams { public: struct EndpointParams endpoint_params; unsigned int dns_ttl_default; unsigned int dns_ttl_min; unsigned int max_fails; public: PolicyAddrParams(); PolicyAddrParams(const struct AddressParams *params); virtual ~PolicyAddrParams() { } }; class EndpointAddress { public: std::string address; std::string host; std::string port; unsigned int fail_count; std::atomic ref; long long broken_timeout; PolicyAddrParams *params; struct address_entry { struct list_head list; EndpointAddress *ptr; } entry; public: EndpointAddress(const std::string& address, PolicyAddrParams *params); virtual ~EndpointAddress() { delete this->params; } }; class WFServiceGovernance : public WFNSPolicy { public: virtual WFRouterTask *create_router_task(const struct WFNSParams *params, router_callback_t callback); virtual void success(RouteManager::RouteResult *result, WFNSTracing *tracing, CommTarget *target); virtual void failed(RouteManager::RouteResult *result, WFNSTracing *tracing, CommTarget *target); virtual void add_server(const std::string& address, const struct AddressParams *params); int remove_server(const std::string& address); virtual int replace_server(const std::string& address, const struct AddressParams *params); void enable_server(const std::string& address); void disable_server(const std::string& address); virtual void get_current_address(std::vector& addr_list); void set_mttr_seconds(unsigned int seconds) { this->mttr_seconds = seconds; } static bool in_select_history(WFNSTracing *tracing, EndpointAddress *addr); public: using pre_select_t = std::function; void set_pre_select(pre_select_t pre_select) { pre_select_ = std::move(pre_select); } private: virtual bool select(const ParsedURI& uri, WFNSTracing *tracing, EndpointAddress **addr); virtual void recover_one_server(const EndpointAddress *addr) { this->nalives++; } virtual void fuse_one_server(const EndpointAddress *addr) { this->nalives--; } virtual void 
add_server_locked(EndpointAddress *addr); virtual int remove_server_locked(const std::string& address); void recover_server_from_breaker(EndpointAddress *addr); void fuse_server_to_breaker(EndpointAddress *addr); void check_breaker_locked(int64_t cur_time); private: struct list_head breaker_list; pthread_mutex_t breaker_lock; unsigned int mttr_seconds; pre_select_t pre_select_; protected: virtual EndpointAddress *first_strategy(const ParsedURI& uri, WFNSTracing *tracing); virtual EndpointAddress *another_strategy(const ParsedURI& uri, WFNSTracing *tracing); void check_breaker(); void try_clear_breaker(); void pre_delete_server(EndpointAddress *addr); struct TracingData { std::vector history; WFServiceGovernance *sg; }; static void tracing_deleter(void *data); std::vector servers; std::unordered_map> server_map; pthread_rwlock_t rwlock; std::atomic nalives; bool try_another; public: WFServiceGovernance(); virtual ~WFServiceGovernance(); friend class WFSGResolverTask; }; #endif workflow-0.11.8/src/nameservice/xmake.lua000066400000000000000000000002711476003635400203630ustar00rootroot00000000000000target("nameservice") add_files("*.cc") set_kind("object") if not has_config("upstream") then remove_files("WFServiceGovernance.cc", "UpstreamPolicies.cc") end workflow-0.11.8/src/protocol/000077500000000000000000000000001476003635400161135ustar00rootroot00000000000000workflow-0.11.8/src/protocol/CMakeLists.txt000066400000000000000000000013511476003635400206530ustar00rootroot00000000000000cmake_minimum_required(VERSION 3.6) project(protocol) set(SRC PackageWrapper.cc SSLWrapper.cc dns_parser.c DnsMessage.cc DnsUtil.cc http_parser.c HttpMessage.cc HttpUtil.cc TLVMessage.cc ) if (NOT MYSQL STREQUAL "n") set(SRC ${SRC} mysql_stream.c mysql_parser.c mysql_byteorder.c MySQLMessage.cc MySQLResult.cc MySQLUtil.cc ) endif () if (NOT REDIS STREQUAL "n") set(SRC ${SRC} redis_parser.c RedisMessage.cc ) endif () add_library(${PROJECT_NAME} OBJECT ${SRC}) if (KAFKA STREQUAL "y") set(SRC 
kafka_parser.c KafkaMessage.cc KafkaDataTypes.cc KafkaResult.cc ) add_library("protocol_kafka" OBJECT ${SRC}) set_property(SOURCE KafkaMessage.cc APPEND PROPERTY COMPILE_OPTIONS "-fno-rtti") endif () workflow-0.11.8/src/protocol/ConsulDataTypes.h000066400000000000000000000204721476003635400213530ustar00rootroot00000000000000/* Copyright (c) 2022 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Authors: Wang Zhenpeng (wangzhenpeng@sogou-inc.com) */ #ifndef _CONSULDATATYPES_H_ #define _CONSULDATATYPES_H_ #include #include #include #include #include namespace protocol { class ConsulConfig { public: // common config void set_token(const std::string& token) { this->ptr->token = token; } std::string get_token() const { return this->ptr->token; } // discover config void set_datacenter(const std::string& data_center) { this->ptr->dc = data_center; } std::string get_datacenter() const { return this->ptr->dc; } void set_near_node(const std::string& near_node) { this->ptr->near = near_node; } std::string get_near_node() const { return this->ptr->near; } void set_filter_expr(const std::string& filter_expr) { this->ptr->filter = filter_expr; } std::string get_filter_expr() const { return this->ptr->filter; } // blocking query wait, limited to 10 minutes, default:5m, unit:ms void set_wait_ttl(int wait_ttl) { this->ptr->wait_ttl = wait_ttl; } int get_wait_ttl() const { return this->ptr->wait_ttl; } // enable blocking query void set_blocking_query(bool enable_flag) { this->ptr->blocking_query = enable_flag; } 
bool blocking_query() const { return this->ptr->blocking_query; } // only get health passing status service instance void set_passing(bool passing) { this->ptr->passing = passing; } bool get_passing() const { return this->ptr->passing; } // register config void set_replace_checks(bool replace_checks) { this->ptr->replace_checks = replace_checks; } bool get_replace_checks() const { return this->ptr->replace_checks; } void set_check_name(const std::string& check_name) { this->ptr->check_cfg.check_name = check_name; } std::string get_check_name() const { return this->ptr->check_cfg.check_name; } void set_check_http_url(const std::string& http_url) { this->ptr->check_cfg.http_url = http_url; } std::string get_check_http_url() const { return this->ptr->check_cfg.http_url; } void set_check_http_method(const std::string& method) { this->ptr->check_cfg.http_method = method; } std::string get_check_http_method() const { return this->ptr->check_cfg.http_method; } void add_http_header(const std::string& key, const std::vector& values) { this->ptr->check_cfg.headers.emplace(key, values); } const std::map> *get_http_headers() const { return &this->ptr->check_cfg.headers; } void set_http_body(const std::string& body) { this->ptr->check_cfg.http_body = body; } std::string get_http_body() const { return this->ptr->check_cfg.http_body; } void set_check_interval(int interval) { this->ptr->check_cfg.interval = interval; } int get_check_interval() const { return this->ptr->check_cfg.interval; } void set_check_timeout(int timeout) { this->ptr->check_cfg.timeout = timeout; } int get_check_timeout() const { return this->ptr->check_cfg.timeout; } void set_check_notes(const std::string& notes) { this->ptr->check_cfg.notes = notes; } std::string get_check_notes() const { return this->ptr->check_cfg.notes; } void set_check_tcp(const std::string& tcp_address) { this->ptr->check_cfg.tcp_address = tcp_address; } std::string get_check_tcp() const { return this->ptr->check_cfg.tcp_address; } void 
set_initial_status(const std::string& initial_status) { this->ptr->check_cfg.initial_status = initial_status; } std::string get_initial_status() const { return this->ptr->check_cfg.initial_status; } void set_auto_deregister_time(int milliseconds) { this->ptr->check_cfg.auto_deregister_time = milliseconds; } int get_auto_deregister_time() const { return this->ptr->check_cfg.auto_deregister_time; } // set success times before passing, refer to success_before_passing, default:0 void set_success_times(int times) { this->ptr->check_cfg.success_times = times; } int get_success_times() const { return this->ptr->check_cfg.success_times; } // set failure times before critical, refer to failures_before_critical, default:0 void set_failure_times(int times) { this->ptr->check_cfg.failure_times = times; } int get_failure_times() const { return this->ptr->check_cfg.failure_times; } void set_health_check(bool enable_flag) { this->ptr->check_cfg.health_check = enable_flag; } bool get_health_check() const { return this->ptr->check_cfg.health_check; } public: ConsulConfig() { this->ptr = new Config; this->ptr->blocking_query = false; this->ptr->passing = false; this->ptr->replace_checks = false; this->ptr->wait_ttl = 300 * 1000; this->ptr->check_cfg.interval = 5000; this->ptr->check_cfg.timeout = 10000; this->ptr->check_cfg.http_method = "GET"; this->ptr->check_cfg.initial_status = "critical"; this->ptr->check_cfg.auto_deregister_time = 10 * 60 * 1000; this->ptr->check_cfg.success_times = 0; this->ptr->check_cfg.failure_times = 0; this->ptr->check_cfg.health_check = false; this->ref = new std::atomic(1); } virtual ~ConsulConfig() { if (--*this->ref == 0) { delete this->ptr; delete this->ref; } } ConsulConfig(ConsulConfig&& move) { this->ptr = move.ptr; this->ref = move.ref; move.ptr = new Config; move.ref = new std::atomic(1); } ConsulConfig(const ConsulConfig& copy) { this->ptr = copy.ptr; this->ref = copy.ref; ++(*this->ref); } ConsulConfig& operator= (ConsulConfig&& move) { if 
(this != &move) { this->~ConsulConfig(); this->ptr = move.ptr; this->ref = move.ref; move.ptr = new Config; move.ref = new std::atomic(1); } return *this; } ConsulConfig& operator= (const ConsulConfig& copy) { if (this != ©) { this->~ConsulConfig(); this->ptr = copy.ptr; this->ref = copy.ref; ++(*this->ref); } return *this; } private: // register health check config struct HealthCheckConfig { std::string check_name; std::string notes; std::string http_url; std::string http_method; std::string http_body; std::string tcp_address; std::string initial_status; // passing or critical, default:critical std::map> headers; int auto_deregister_time; // refer to deregister_critical_service_after int interval; int timeout; // default 10000 int success_times; // default:0 success times before passing int failure_times; // default:0 failure_before_critical bool health_check; }; struct Config { // common config std::string token; // discover config std::string dc; std::string near; std::string filter; int wait_ttl; bool blocking_query; bool passing; // register config bool replace_checks; //refer to replace_existing_checks HealthCheckConfig check_cfg; }; private: struct Config *ptr; std::atomic *ref; }; // k:address, v:port using ConsulAddress = std::pair; struct ConsulService { std::string service_name; std::string service_namespace; std::string service_id; std::vector tags; ConsulAddress service_address; ConsulAddress lan; ConsulAddress lan_ipv4; ConsulAddress lan_ipv6; ConsulAddress virtual_address; ConsulAddress wan; ConsulAddress wan_ipv4; ConsulAddress wan_ipv6; std::map meta; bool tag_override; }; struct ConsulServiceInstance { // node info std::string node_id; std::string node_name; std::string node_address; std::string dc; std::map node_meta; long long create_index; long long modify_index; // service info struct ConsulService service; // service health check std::string check_name; std::string check_id; std::string check_notes; std::string check_output; std::string 
check_status; std::string check_type; }; struct ConsulServiceTags { std::string service_name; std::vector tags; }; } #endif workflow-0.11.8/src/protocol/DnsMessage.cc000066400000000000000000000162621476003635400204620ustar00rootroot00000000000000/* Copyright (c) 2021 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Liu Kai (liukaidx@sogou-inc.com) */ #include #include #include "DnsMessage.h" #define DNS_LABELS_MAX 63 #define DNS_MESSAGE_MAX_UDP_SIZE 512 namespace protocol { static inline void __append_uint8(std::string& s, uint8_t tmp) { s.append((const char *)&tmp, sizeof (uint8_t)); } static inline void __append_uint16(std::string& s, uint16_t tmp) { tmp = htons(tmp); s.append((const char *)&tmp, sizeof (uint16_t)); } static inline void __append_uint32(std::string& s, uint32_t tmp) { tmp = htonl(tmp); s.append((const char *)&tmp, sizeof (uint32_t)); } static inline int __append_name(std::string& s, const char *p) { const char *name; size_t len; while (*p) { name = p; while (*p && *p != '.') p++; len = p - name; if (len > DNS_LABELS_MAX || (len == 0 && *p && *(p + 1))) { errno = EINVAL; return -1; } if (len > 0) { __append_uint8(s, len); s.append(name, len); } if (*p == '.') p++; } len = 0; __append_uint8(s, len); return 0; } static inline int __append_record_list(std::string& s, int *count, dns_record_cursor_t *cursor) { int cnt = 0; struct dns_record *record; std::string record_buf; std::string rdata_buf; int ret; while (dns_record_cursor_next(&record, cursor) == 0) { 
record_buf.clear(); ret = __append_name(record_buf, record->name); if (ret < 0) return ret; __append_uint16(record_buf, record->type); __append_uint16(record_buf, record->rclass); __append_uint32(record_buf, record->ttl); switch (record->type) { default: // encode unknown types as raw record case DNS_TYPE_A: case DNS_TYPE_AAAA: __append_uint16(record_buf, record->rdlength); record_buf.append((const char *)record->rdata, record->rdlength); break; case DNS_TYPE_NS: case DNS_TYPE_CNAME: case DNS_TYPE_PTR: rdata_buf.clear(); ret = __append_name(rdata_buf, (const char *)record->rdata); if (ret < 0) return ret; __append_uint16(record_buf, rdata_buf.size()); record_buf.append(rdata_buf); break; case DNS_TYPE_SOA: { auto *soa = (struct dns_record_soa *)record->rdata; rdata_buf.clear(); ret = __append_name(rdata_buf, soa->mname); if (ret < 0) return ret; ret = __append_name(rdata_buf, soa->rname); if (ret < 0) return ret; __append_uint32(rdata_buf, soa->serial); __append_uint32(rdata_buf, soa->refresh); __append_uint32(rdata_buf, soa->retry); __append_uint32(rdata_buf, soa->expire); __append_uint32(rdata_buf, soa->minimum); __append_uint16(record_buf, rdata_buf.size()); record_buf.append(rdata_buf); break; } case DNS_TYPE_SRV: { auto *srv = (struct dns_record_srv *)record->rdata; rdata_buf.clear(); __append_uint16(rdata_buf, srv->priority); __append_uint16(rdata_buf, srv->weight); __append_uint16(rdata_buf, srv->port); ret = __append_name(rdata_buf, srv->target); if (ret < 0) return ret; __append_uint16(record_buf, rdata_buf.size()); record_buf.append(rdata_buf); break; } case DNS_TYPE_MX: { auto *mx = (struct dns_record_mx *)record->rdata; rdata_buf.clear(); __append_uint16(rdata_buf, mx->preference); ret = __append_name(rdata_buf, mx->exchange); if (ret < 0) return ret; __append_uint16(record_buf, rdata_buf.size()); record_buf.append(rdata_buf); break; } } cnt++; s.append(record_buf); } if (count) *count = cnt; return 0; } DnsMessage::DnsMessage(DnsMessage&& msg) : 
ProtocolMessage(std::move(msg)) { this->parser = msg.parser; msg.parser = NULL; this->cur_size = msg.cur_size; msg.cur_size = 0; } DnsMessage& DnsMessage::operator = (DnsMessage&& msg) { if (&msg != this) { *(ProtocolMessage *)this = std::move(msg); if (this->parser) { dns_parser_deinit(this->parser); delete this->parser; } this->parser = msg.parser; msg.parser = NULL; this->cur_size = msg.cur_size; msg.cur_size = 0; } return *this; } int DnsMessage::encode_reply() { dns_record_cursor_t cursor; struct dns_header h; std::string tmpbuf; const char *p; int ancount; int nscount; int arcount; int ret; msgbuf.clear(); msgsize = 0; // TODO // this is an incomplete and inefficient way, compress not used, // pointers can only be used for occurances of a domain name where // the format is not class specific dns_answer_cursor_init(&cursor, this->parser); ret = __append_record_list(tmpbuf, &ancount, &cursor); dns_record_cursor_deinit(&cursor); if (ret < 0) return ret; dns_authority_cursor_init(&cursor, this->parser); ret = __append_record_list(tmpbuf, &nscount, &cursor); dns_record_cursor_deinit(&cursor); if (ret < 0) return ret; dns_additional_cursor_init(&cursor, this->parser); ret = __append_record_list(tmpbuf, &arcount, &cursor); dns_record_cursor_deinit(&cursor); if (ret < 0) return ret; h = this->parser->header; h.id = htons(h.id); h.qdcount = htons(1); h.ancount = htons(ancount); h.nscount = htons(nscount); h.arcount = htons(arcount); msgbuf.append((const char *)&h, sizeof (struct dns_header)); p = parser->question.qname ? 
parser->question.qname : "."; ret = __append_name(msgbuf, p); if (ret < 0) return ret; __append_uint16(msgbuf, parser->question.qtype); __append_uint16(msgbuf, parser->question.qclass); msgbuf.append(tmpbuf); if (msgbuf.size() >= (1 << 16)) { errno = EOVERFLOW; return -1; } msgsize = htons(msgbuf.size()); return 0; } int DnsMessage::encode(struct iovec vectors[], int) { struct iovec *p = vectors; if (this->encode_reply() < 0) return -1; // TODO // if this is a request, it won't exceed the 512 bytes UDP limit // if this is a response and exceed 512 bytes, we need a TrunCation reply if (!this->is_single_packet()) { p->iov_base = &this->msgsize; p->iov_len = sizeof (uint16_t); p++; } p->iov_base = (void *)this->msgbuf.data(); p->iov_len = msgbuf.size(); return p - vectors + 1; } int DnsMessage::append(const void *buf, size_t *size) { int ret = dns_parser_append_message(buf, size, this->parser); if (ret >= 0) { this->cur_size += *size; if (this->cur_size > this->size_limit) { errno = EMSGSIZE; ret = -1; } } else if (ret == -2) { errno = EBADMSG; ret = -1; } return ret; } int DnsResponse::append(const void *buf, size_t *size) { int ret = this->DnsMessage::append(buf, size); const char *qname = this->parser->question.qname; if (ret >= 1 && (this->request_id != this->get_id() || strcasecmp(this->request_name.c_str(), qname) != 0)) { if (!this->is_single_packet()) { errno = EBADMSG; ret = -1; } else { dns_parser_deinit(this->parser); dns_parser_init(this->parser); ret = 0; } } return ret; } } workflow-0.11.8/src/protocol/DnsMessage.h000066400000000000000000000165771476003635400203350ustar00rootroot00000000000000/* Copyright (c) 2021 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Liu Kai (liukaidx@sogou-inc.com) */ #ifndef _DNSMESSAGE_H_ #define _DNSMESSAGE_H_ /** * @file DnsMessage.h * @brief Dns Protocol Interface */ #include #include #include "ProtocolMessage.h" #include "dns_parser.h" namespace protocol { class DnsMessage : public ProtocolMessage { protected: virtual int encode(struct iovec vectors[], int max); virtual int append(const void *buf, size_t *size); public: int get_id() const { return parser->header.id; } int get_qr() const { return parser->header.qr; } int get_opcode() const { return parser->header.opcode; } int get_aa() const { return parser->header.aa; } int get_tc() const { return parser->header.tc; } int get_rd() const { return parser->header.rd; } int get_ra() const { return parser->header.ra; } int get_rcode() const { return parser->header.rcode; } int get_qdcount() const { return parser->header.qdcount; } int get_ancount() const { return parser->header.ancount; } int get_nscount() const { return parser->header.nscount; } int get_arcount() const { return parser->header.arcount; } void set_id(int id) { parser->header.id = id; } void set_qr(int qr) { parser->header.qr = qr; } void set_opcode(int opcode) { parser->header.opcode = opcode; } void set_aa(int aa) { parser->header.aa = aa; } void set_tc(int tc) { parser->header.tc = tc; } void set_rd(int rd) { parser->header.rd = rd; } void set_ra(int ra) { parser->header.ra = ra; } void set_rcode(int rcode) { parser->header.rcode = rcode; } int get_question_type() const { return parser->question.qtype; } int get_question_class() const { return parser->question.qclass; } std::string 
get_question_name() const { const char *name = parser->question.qname; if (name == NULL) return ""; return name; } void set_question_type(int qtype) { parser->question.qtype = qtype; } void set_question_class(int qclass) { parser->question.qclass = qclass; } void set_question_name(const std::string& name) { char *pname = parser->question.qname; if (pname != NULL) free(pname); parser->question.qname = strdup(name.c_str()); } int add_a_record(int section, const char *name, uint16_t rclass, uint32_t ttl, const void *data) { struct list_head *list = get_section(section); if (!list) return -1; return dns_add_raw_record(name, DNS_TYPE_A, rclass, ttl, 4, data, list); } int add_aaaa_record(int section, const char *name, uint16_t rclass, uint32_t ttl, const void *data) { struct list_head *list = get_section(section); if (!list) return -1; return dns_add_raw_record(name, DNS_TYPE_AAAA, rclass, ttl, 16, data, list); } int add_ns_record(int section, const char *name, uint16_t rclass, uint32_t ttl, const char *data) { struct list_head *list = get_section(section); if (!list) return -1; return dns_add_str_record(name, DNS_TYPE_NS, rclass, ttl, data, list); } int add_cname_record(int section, const char *name, uint16_t rclass, uint32_t ttl, const char *data) { struct list_head *list = get_section(section); if (!list) return -1; return dns_add_str_record(name, DNS_TYPE_CNAME, rclass, ttl, data, list); } int add_ptr_record(int section, const char *name, uint16_t rclass, uint32_t ttl, const char *data) { struct list_head *list = get_section(section); if (!list) return -1; return dns_add_str_record(name, DNS_TYPE_PTR, rclass, ttl, data, list); } int add_soa_record(int section, const char *name, uint16_t rclass, uint32_t ttl, const char *mname, const char *rname, uint32_t serial, int32_t refresh, int32_t retry, int32_t expire, uint32_t minimum) { struct list_head *list = get_section(section); if (!list) return -1; return dns_add_soa_record(name, rclass, ttl, mname, rname, serial, 
refresh, retry, expire, minimum, list); } int add_srv_record(int section, const char *name, uint16_t rclass, uint32_t ttl, uint16_t priority, uint16_t weight, uint16_t port, const char *target) { struct list_head *list = get_section(section); if (!list) return -1; return dns_add_srv_record(name, rclass, ttl, priority, weight, port, target, list); } int add_mx_record(int section, const char *name, uint16_t rclass, uint32_t ttl, int16_t preference, const char *exchange) { struct list_head *list = get_section(section); if (!list) return -1; return dns_add_mx_record(name, rclass, ttl, preference, exchange, list); } int add_raw_record(int section, const char *name, uint16_t type, uint16_t rclass, uint32_t ttl, const void *data, uint16_t dlen) { struct list_head *list = get_section(section); if (!list) return -1; return dns_add_raw_record(name, type, rclass, ttl, dlen, data, list); } // Inner use only bool is_single_packet() const { return parser->single_packet; } void set_single_packet(bool single) { parser->single_packet = single; } public: DnsMessage() : parser(new dns_parser_t) { dns_parser_init(parser); this->cur_size = 0; } virtual ~DnsMessage() { if (this->parser) { dns_parser_deinit(parser); delete this->parser; } } DnsMessage(DnsMessage&& msg); DnsMessage& operator = (DnsMessage&& msg); protected: dns_parser_t *parser; std::string msgbuf; size_t cur_size; private: int encode_reply(); int encode_truncation_reply(); struct list_head *get_section(int section) { switch (section) { case DNS_ANSWER_SECTION: return &(parser->answer_list); case DNS_AUTHORITY_SECTION: return &(parser->authority_list); case DNS_ADDITIONAL_SECTION: return &(parser->additional_list); default: errno = EINVAL; return NULL; } } // size of msgbuf, but in network byte order uint16_t msgsize; }; class DnsRequest : public DnsMessage { public: DnsRequest() = default; DnsRequest(DnsRequest&& req) = default; DnsRequest& operator = (DnsRequest&& req) = default; void set_question(const char *host, 
uint16_t qtype, uint16_t qclass) { dns_parser_set_question(host, qtype, qclass, this->parser); } }; class DnsResponse : public DnsMessage { public: DnsResponse() { this->request_id = 0; } DnsResponse(DnsResponse&& req) = default; DnsResponse& operator = (DnsResponse&& req) = default; const dns_parser_t *get_parser() const { return this->parser; } void set_request_id(uint16_t id) { this->request_id = id; } void set_request_name(const std::string& name) { std::string& req_name = this->request_name; req_name = name; while (req_name.size() > 1 && req_name.back() == '.') req_name.pop_back(); } protected: virtual int append(const void *buf, size_t *size); private: uint16_t request_id; std::string request_name; }; } #endif workflow-0.11.8/src/protocol/DnsUtil.cc000066400000000000000000000061561476003635400200140ustar00rootroot00000000000000/* Copyright (c) 2021 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
namespace protocol
{

/*
 * Build an addrinfo list from the A/AAAA answers of `resp`.
 *
 * The question name is followed through the CNAME chain in the answer
 * section; records of class IN whose owner name matches the final name
 * (case-insensitively) are converted, with `port` stored in network byte
 * order. The canonical name is attached to the first list element.
 *
 * Returns 0 and stores the list in `*addrinfo` on success (release it with
 * DnsUtil::freeaddrinfo()). Otherwise returns an EAI_* code and leaves
 * `*addrinfo` untouched:
 *   EAI_NONAME / EAI_AGAIN / EAI_FAIL  - mapped from the response rcode
 *   EAI_NODATA                         - no usable address record
 *   EAI_SYSTEM                         - memory allocation failed
 */
int DnsUtil::getaddrinfo(const DnsResponse *resp,
						 unsigned short port,
						 struct addrinfo **addrinfo)
{
	int ancount = resp->get_ancount();
	int rcode = resp->get_rcode();
	int status = 0;
	struct addrinfo *res = NULL;
	struct addrinfo **pres = &res;
	struct dns_record *record;
	struct addrinfo *ai;
	std::string qname;
	const char *cname;
	int family;
	int addrlen;
	size_t datalen;

	switch (rcode)
	{
	case DNS_RCODE_NAME_ERROR:
		status = EAI_NONAME;
		break;
	case DNS_RCODE_SERVER_FAILURE:
		status = EAI_AGAIN;
		break;
	case DNS_RCODE_FORMAT_ERROR:
	case DNS_RCODE_NOT_IMPLEMENTED:
	case DNS_RCODE_REFUSED:
		status = EAI_FAIL;
		break;
	}

	qname = resp->get_question_name();
	cname = qname.c_str();

	DnsResultCursor cursor(resp);
	cursor.reset_answer_cursor();
	/* Forbid loop in cname chain */
	while (cursor.find_cname(cname, &cname) && ancount-- > 0) { }

	if (rcode == DNS_RCODE_NO_ERROR && ancount <= 0)
		status = EAI_NODATA;

	if (status != 0)
		return status;

	cursor.reset_answer_cursor();
	while (cursor.next(&record))
	{
		if (!(record->rclass == DNS_CLASS_IN &&
			  (record->type == DNS_TYPE_A || record->type == DNS_TYPE_AAAA) &&
			  strcasecmp(record->name, cname) == 0))
			continue;

		if (record->type == DNS_TYPE_A)
		{
			family = AF_INET;
			addrlen = sizeof (struct sockaddr_in);
			datalen = sizeof (struct in_addr);
		}
		else
		{
			family = AF_INET6;
			addrlen = sizeof (struct sockaddr_in6);
			datalen = sizeof (struct in6_addr);
		}

		/* Records can be created with an arbitrary data length (e.g. via
		 * add_raw_record()); never copy more than the record holds. */
		if (record->rdlength < datalen)
			continue;

		/* The sockaddr lives in the same allocation, right after the
		 * addrinfo, so a single free() releases both. */
		ai = (struct addrinfo *)calloc(sizeof (struct addrinfo) + addrlen, 1);
		if (ai == NULL)
		{
			if (res)
				DnsUtil::freeaddrinfo(res);
			return EAI_SYSTEM;
		}

		ai->ai_family = family;
		ai->ai_addrlen = addrlen;
		ai->ai_addr = (struct sockaddr *)(ai + 1);
		ai->ai_addr->sa_family = family;

		if (family == AF_INET)
		{
			struct sockaddr_in *in = (struct sockaddr_in *)(ai->ai_addr);
			in->sin_port = htons(port);
			memcpy(&in->sin_addr, record->rdata, sizeof (struct in_addr));
		}
		else
		{
			struct sockaddr_in6 *in = (struct sockaddr_in6 *)(ai->ai_addr);
			in->sin6_port = htons(port);
			memcpy(&in->sin6_addr, record->rdata, sizeof (struct in6_addr));
		}

		*pres = ai;
		pres = &ai->ai_next;
	}

	if (res == NULL)
		return EAI_NODATA;

	if (cname)
		res->ai_canonname = strdup(cname);

	*addrinfo = res;
	return 0;
}

/* Release a list produced by DnsUtil::getaddrinfo(). NULL is accepted. */
void DnsUtil::freeaddrinfo(struct addrinfo *ai)
{
	struct addrinfo *p;

	while (ai != NULL)
	{
		p = ai;
		ai = ai->ai_next;
		free(p->ai_canonname);
		free(p);
	}
}

}
namespace protocol
{

/* Static helpers that turn a DnsResponse into an addrinfo list. */
class DnsUtil
{
public:
	static int getaddrinfo(const DnsResponse *resp,
						   unsigned short port,
						   struct addrinfo **res);
	static void freeaddrinfo(struct addrinfo *ai);
};

/*
 * Iterates over the records of one section of a DnsResponse. The cursor
 * starts on the answer section; the reset_*_cursor() methods re-point it.
 * The response must outlive the cursor, since only its parser pointer is kept.
 */
class DnsResultCursor
{
public:
	DnsResultCursor(const DnsResponse *resp) :
		parser(resp->get_parser()),
		record(NULL)
	{
		dns_answer_cursor_init(&this->cursor, this->parser);
	}

	DnsResultCursor(DnsResultCursor&& move) = delete;
	DnsResultCursor& operator=(DnsResultCursor&& move) = delete;

	virtual ~DnsResultCursor() { }

	void reset_answer_cursor()
	{
		dns_answer_cursor_init(&this->cursor, this->parser);
	}

	void reset_authority_cursor()
	{
		dns_authority_cursor_init(&this->cursor, this->parser);
	}

	void reset_additional_cursor()
	{
		dns_additional_cursor_init(&this->cursor, this->parser);
	}

	/* Advances to the next record. On success stores it in
	 * `*next_record` and returns true; otherwise returns false and
	 * leaves `*next_record` unchanged. */
	bool next(struct dns_record **next_record)
	{
		if (dns_record_cursor_next(&this->record, &this->cursor) != 0)
		{
			this->record = NULL;
			return false;
		}

		*next_record = this->record;
		return true;
	}

	/* True when dns_record_cursor_find_cname() reports success for
	 * `name`; the result is written to `*cname`. */
	bool find_cname(const char *name, const char **cname)
	{
		int ret = dns_record_cursor_find_cname(name, cname, &this->cursor);

		return ret == 0;
	}

private:
	const dns_parser_t *parser;
	dns_record_cursor_t cursor;
	struct dns_record *record;
};

}
Author: Xie Han (xiehan@sogou-inc.com) */ #include #include #include #include #include "HttpMessage.h" namespace protocol { struct HttpMessageBlock { struct list_head list; const void *ptr; size_t size; }; bool HttpMessage::append_output_body(const void *buf, size_t size) { size_t n = sizeof (struct HttpMessageBlock) + size; struct HttpMessageBlock *block = (struct HttpMessageBlock *)malloc(n); if (block) { memcpy(block + 1, buf, size); block->ptr = block + 1; block->size = size; list_add_tail(&block->list, &this->output_body); this->output_body_size += size; return true; } return false; } bool HttpMessage::append_output_body_nocopy(const void *buf, size_t size) { size_t n = sizeof (struct HttpMessageBlock); struct HttpMessageBlock *block = (struct HttpMessageBlock *)malloc(n); if (block) { block->ptr = buf; block->size = size; list_add_tail(&block->list, &this->output_body); this->output_body_size += size; return true; } return false; } size_t HttpMessage::get_output_body_blocks(const void *buf[], size_t size[], size_t max) const { struct HttpMessageBlock *block; struct list_head *pos; size_t n = 0; list_for_each(pos, &this->output_body) { if (n == max) break; block = list_entry(pos, struct HttpMessageBlock, list); buf[n] = block->ptr; size[n] = block->size; n++; } return n; } bool HttpMessage::get_output_body_merged(void *buf, size_t *size) const { struct HttpMessageBlock *block; struct list_head *pos; if (*size < this->output_body_size) { errno = ENOSPC; return false; } list_for_each(pos, &this->output_body) { block = list_entry(pos, struct HttpMessageBlock, list); memcpy(buf, block->ptr, block->size); buf = (char *)buf + block->size; } *size = this->output_body_size; return true; } void HttpMessage::clear_output_body() { struct HttpMessageBlock *block; struct list_head *pos, *tmp; list_for_each_safe(pos, tmp, &this->output_body) { block = list_entry(pos, struct HttpMessageBlock, list); list_del(pos); free(block); } this->output_body_size = 0; } struct list_head 
*HttpMessage::combine_from(struct list_head *pos, size_t size) { size_t n = sizeof (struct HttpMessageBlock) + size; struct HttpMessageBlock *block = (struct HttpMessageBlock *)malloc(n); struct HttpMessageBlock *entry; char *ptr; if (block) { block->ptr = block + 1; block->size = size; ptr = (char *)block->ptr; do { entry = list_entry(pos, struct HttpMessageBlock, list); pos = pos->next; list_del(&entry->list); memcpy(ptr, entry->ptr, entry->size); ptr += entry->size; free(entry); } while (pos != &this->output_body); list_add_tail(&block->list, &this->output_body); return &block->list; } return NULL; } int HttpMessage::encode(struct iovec vectors[], int max) { const char *start_line[3]; http_header_cursor_t cursor; struct HttpMessageHeader header; struct HttpMessageBlock *block; struct list_head *pos; size_t size; int i; start_line[0] = http_parser_get_method(this->parser); if (start_line[0]) { start_line[1] = http_parser_get_uri(this->parser); start_line[2] = http_parser_get_version(this->parser); } else { start_line[0] = http_parser_get_version(this->parser); start_line[1] = http_parser_get_code(this->parser); start_line[2] = http_parser_get_phrase(this->parser); } if (!start_line[0] || !start_line[1] || !start_line[2]) { errno = EBADMSG; return -1; } vectors[0].iov_base = (void *)start_line[0]; vectors[0].iov_len = strlen(start_line[0]); vectors[1].iov_base = (void *)" "; vectors[1].iov_len = 1; vectors[2].iov_base = (void *)start_line[1]; vectors[2].iov_len = strlen(start_line[1]); vectors[3].iov_base = (void *)" "; vectors[3].iov_len = 1; vectors[4].iov_base = (void *)start_line[2]; vectors[4].iov_len = strlen(start_line[2]); vectors[5].iov_base = (void *)"\r\n"; vectors[5].iov_len = 2; i = 6; http_header_cursor_init(&cursor, this->parser); while (http_header_cursor_next(&header.name, &header.name_len, &header.value, &header.value_len, &cursor) == 0) { if (i == max) break; vectors[i].iov_base = (void *)header.name; vectors[i].iov_len = header.name_len + 2 + 
header.value_len + 2; i++; } http_header_cursor_deinit(&cursor); if (i + 1 >= max) { errno = EOVERFLOW; return -1; } vectors[i].iov_base = (void *)"\r\n"; vectors[i].iov_len = 2; i++; size = this->output_body_size; list_for_each(pos, &this->output_body) { if (i + 1 == max && pos != this->output_body.prev) { pos = this->combine_from(pos, size); if (!pos) return -1; } block = list_entry(pos, struct HttpMessageBlock, list); vectors[i].iov_base = (void *)block->ptr; vectors[i].iov_len = block->size; size -= block->size; i++; } return i; } inline int HttpMessage::append(const void *buf, size_t *size) { int ret = http_parser_append_message(buf, size, this->parser); if (ret >= 0) { this->cur_size += *size; if (this->cur_size > this->size_limit) { errno = EMSGSIZE; ret = -1; } } else if (ret == -2) { errno = EBADMSG; ret = -1; } return ret; } HttpMessage::HttpMessage(HttpMessage&& msg) : ProtocolMessage(std::move(msg)) { this->parser = msg.parser; msg.parser = NULL; INIT_LIST_HEAD(&this->output_body); list_splice_init(&msg.output_body, &this->output_body); this->output_body_size = msg.output_body_size; msg.output_body_size = 0; this->cur_size = msg.cur_size; msg.cur_size = 0; } HttpMessage& HttpMessage::operator = (HttpMessage&& msg) { if (&msg != this) { *(ProtocolMessage *)this = std::move(msg); if (this->parser) { http_parser_deinit(this->parser); delete this->parser; } this->parser = msg.parser; msg.parser = NULL; this->clear_output_body(); list_splice_init(&msg.output_body, &this->output_body); this->output_body_size = msg.output_body_size; msg.output_body_size = 0; this->cur_size = msg.cur_size; msg.cur_size = 0; } return *this; } #define HTTP_100_STATUS_LINE "HTTP/1.1 100 Continue" #define HTTP_400_STATUS_LINE "HTTP/1.1 400 Bad Request" #define HTTP_413_STATUS_LINE "HTTP/1.1 413 Request Entity Too Large" #define HTTP_417_STATUS_LINE "HTTP/1.1 417 Expectation Failed" #define CONTENT_LENGTH_ZERO "Content-Length: 0" #define CONNECTION_CLOSE "Connection: close" #define 
CRLF "\r\n" #define HTTP_100_RESP HTTP_100_STATUS_LINE CRLF \ CRLF #define HTTP_400_RESP HTTP_400_STATUS_LINE CRLF \ CONTENT_LENGTH_ZERO CRLF \ CONNECTION_CLOSE CRLF \ CRLF #define HTTP_413_RESP HTTP_413_STATUS_LINE CRLF \ CONTENT_LENGTH_ZERO CRLF \ CONNECTION_CLOSE CRLF \ CRLF #define HTTP_417_RESP HTTP_417_STATUS_LINE CRLF \ CONTENT_LENGTH_ZERO CRLF \ CONNECTION_CLOSE CRLF \ CRLF int HttpRequest::handle_expect_continue() { size_t trans_len = this->parser->transfer_length; int ret; if (trans_len != (size_t)-1) { if (this->parser->header_offset + trans_len > this->size_limit) { this->feedback(HTTP_417_RESP, strlen(HTTP_417_RESP)); errno = EMSGSIZE; return -1; } } ret = this->feedback(HTTP_100_RESP, strlen(HTTP_100_RESP)); if (ret != strlen(HTTP_100_RESP)) { if (ret >= 0) errno = ENOBUFS; return -1; } return 0; } int HttpRequest::append(const void *buf, size_t *size) { int ret = HttpMessage::append(buf, size); if (ret == 0) { if (this->parser->expect_continue && http_parser_header_complete(this->parser)) { this->parser->expect_continue = 0; ret = this->handle_expect_continue(); } } else if (ret < 0) { if (errno == EBADMSG) this->feedback(HTTP_400_RESP, strlen(HTTP_400_RESP)); else if (errno == EMSGSIZE) this->feedback(HTTP_413_RESP, strlen(HTTP_413_RESP)); } return ret; } int HttpResponse::append(const void *buf, size_t *size) { int ret = HttpMessage::append(buf, size); if (ret > 0) { if (strcmp(http_parser_get_code(this->parser), "100") == 0) { http_parser_deinit(this->parser); http_parser_init(1, this->parser); ret = 0; } } return ret; } } workflow-0.11.8/src/protocol/HttpMessage.h000066400000000000000000000202161476003635400205110ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Xie Han (xiehan@sogou-inc.com) */ #ifndef _HTTPMESSAGE_H_ #define _HTTPMESSAGE_H_ #include #include #include #include "list.h" #include "ProtocolMessage.h" #include "http_parser.h" /** * @file HttpMessage.h * @brief Http Protocol Interface */ namespace protocol { struct HttpMessageHeader { const void *name; size_t name_len; const void *value; size_t value_len; }; class HttpMessage : public ProtocolMessage { public: const char *get_http_version() const { return http_parser_get_version(this->parser); } bool set_http_version(const char *version) { return http_parser_set_version(version, this->parser) == 0; } bool is_chunked() const { return http_parser_chunked(this->parser); } bool is_keep_alive() const { return http_parser_keep_alive(this->parser); } bool add_header(const struct HttpMessageHeader *header) { return http_parser_add_header(header->name, header->name_len, header->value, header->value_len, this->parser) == 0; } bool add_header_pair(const char *name, const char *value) { return http_parser_add_header(name, strlen(name), value, strlen(value), this->parser) == 0; } bool set_header(const struct HttpMessageHeader *header) { return http_parser_set_header(header->name, header->name_len, header->value, header->value_len, this->parser) == 0; } bool set_header_pair(const char *name, const char *value) { return http_parser_set_header(name, strlen(name), value, strlen(value), this->parser) == 0; } bool get_parsed_body(const void **body, size_t *size) const { return http_parser_get_body(body, size, this->parser) == 0; } /* Output body is for sending. 
Want to transfer a message received, maybe: * msg->get_parsed_body(&body, &size); * msg->append_output_body_nocopy(body, size); */ bool append_output_body(const void *buf, size_t size); bool append_output_body(const char *buf) { return this->append_output_body(buf, strlen(buf)); } bool append_output_body_nocopy(const void *buf, size_t size); bool append_output_body_nocopy(const char *buf) { return this->append_output_body_nocopy(buf, strlen(buf)); } size_t get_output_body_size() const { return this->output_body_size; } size_t get_output_body_blocks(const void *buf[], size_t size[], size_t max) const; bool get_output_body_merged(void *buf, size_t *size) const; void clear_output_body(); /* std::string interfaces */ public: bool get_http_version(std::string& version) const { const char *str = this->get_http_version(); if (str) { version.assign(str); return true; } return false; } bool set_http_version(const std::string& version) { return this->set_http_version(version.c_str()); } bool add_header_pair(const std::string& name, const std::string& value) { return http_parser_add_header(name.c_str(), name.size(), value.c_str(), value.size(), this->parser) == 0; } bool set_header_pair(const std::string& name, const std::string& value) { return http_parser_set_header(name.c_str(), name.size(), value.c_str(), value.size(), this->parser) == 0; } bool append_output_body(const std::string& buf) { return this->append_output_body(buf.c_str(), buf.size()); } bool append_output_body_nocopy(const std::string& buf) { return this->append_output_body_nocopy(buf.c_str(), buf.size()); } bool get_output_body_merged(std::string& body) const { size_t size = this->output_body_size; body.resize(size); return this->get_output_body_merged((void *)body.data(), &size); } /* for http task implementations. 
*/ public: bool is_header_complete() const { return http_parser_header_complete(this->parser); } bool has_connection_header() const { return http_parser_has_connection(this->parser); } bool has_content_length_header() const { return http_parser_has_content_length(this->parser); } bool has_keep_alive_header() const { return http_parser_has_keep_alive(this->parser); } void end_parsing() { http_parser_close_message(this->parser); } /* for header cursor implementations. */ const http_parser_t *get_parser() const { return this->parser; } protected: virtual int encode(struct iovec vectors[], int max); virtual int append(const void *buf, size_t *size); protected: http_parser_t *parser; size_t cur_size; private: struct list_head *combine_from(struct list_head *pos, size_t size); private: struct list_head output_body; size_t output_body_size; public: HttpMessage(bool is_resp) : parser(new http_parser_t) { http_parser_init(is_resp, this->parser); INIT_LIST_HEAD(&this->output_body); this->output_body_size = 0; this->cur_size = 0; } virtual ~HttpMessage() { this->clear_output_body(); if (this->parser) { http_parser_deinit(this->parser); delete this->parser; } } public: HttpMessage(HttpMessage&& msg); HttpMessage& operator = (HttpMessage&& msg); }; class HttpRequest : public HttpMessage { public: const char *get_method() const { return http_parser_get_method(this->parser); } const char *get_request_uri() const { return http_parser_get_uri(this->parser); } bool set_method(const char *method) { return http_parser_set_method(method, this->parser) == 0; } bool set_request_uri(const char *uri) { return http_parser_set_uri(uri, this->parser) == 0; } /* std::string interfaces */ public: bool get_method(std::string& method) const { const char *str = this->get_method(); if (str) { method.assign(str); return true; } return false; } bool get_request_uri(std::string& uri) const { const char *str = this->get_request_uri(); if (str) { uri.assign(str); return true; } return false; } bool 
set_method(const std::string& method) { return this->set_method(method.c_str()); } bool set_request_uri(const std::string& uri) { return this->set_request_uri(uri.c_str()); } protected: virtual int append(const void *buf, size_t *size); private: int handle_expect_continue(); public: HttpRequest() : HttpMessage(false) { } public: HttpRequest(HttpRequest&& req) = default; HttpRequest& operator = (HttpRequest&& req) = default; }; class HttpResponse : public HttpMessage { public: const char *get_status_code() const { return http_parser_get_code(this->parser); } const char *get_reason_phrase() const { return http_parser_get_phrase(this->parser); } bool set_status_code(const char *code) { return http_parser_set_code(code, this->parser) == 0; } bool set_reason_phrase(const char *phrase) { return http_parser_set_phrase(phrase, this->parser) == 0; } /* std::string interfaces */ public: bool get_status_code(std::string& code) const { const char *str = this->get_status_code(); if (str) { code.assign(str); return true; } return false; } bool get_reason_phrase(std::string& phrase) const { const char *str = this->get_reason_phrase(); if (str) { phrase.assign(str); return true; } return false; } bool set_status_code(const std::string& code) { return this->set_status_code(code.c_str()); } bool set_reason_phrase(const std::string& phrase) { return this->set_reason_phrase(phrase.c_str()); } public: /* Tell the parser, it is a HEAD response. For implementations. */ void parse_zero_body() { this->parser->transfer_length = 0; } protected: virtual int append(const void *buf, size_t *size); public: HttpResponse() : HttpMessage(true) { } public: HttpResponse(HttpResponse&& resp) = default; HttpResponse& operator = (HttpResponse&& resp) = default; }; } #endif workflow-0.11.8/src/protocol/HttpUtil.cc000066400000000000000000000221401476003635400201760ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Authors: Xie Han (xiehan@sogou-inc.com) Wu Jiaxu (wujiaxu@sogou-inc.com) */ #include #include #include #include "http_parser.h" #include "HttpMessage.h" #include "HttpUtil.h" namespace protocol { HttpHeaderMap::HttpHeaderMap(const HttpMessage *message) { http_header_cursor_t cursor; struct HttpMessageHeader header; http_header_cursor_init(&cursor, message->get_parser()); while (http_header_cursor_next(&header.name, &header.name_len, &header.value, &header.value_len, &cursor) == 0) { std::string key((const char *)header.name, header.name_len); std::transform(key.begin(), key.end(), key.begin(), ::tolower); header_map_[key].emplace_back((const char *)header.value, header.value_len); } http_header_cursor_deinit(&cursor); } bool HttpHeaderMap::key_exists(std::string key) { std::transform(key.begin(), key.end(), key.begin(), ::tolower); return header_map_.count(key) > 0; } std::string HttpHeaderMap::get(std::string key) { std::transform(key.begin(), key.end(), key.begin(), ::tolower); const auto it = header_map_.find(key); if (it == header_map_.end() || it->second.empty()) return std::string(); return it->second[0]; } bool HttpHeaderMap::get(std::string key, std::string& value) { std::transform(key.begin(), key.end(), key.begin(), ::tolower); const auto it = header_map_.find(key); if (it == header_map_.end() || it->second.empty()) return false; value = it->second[0]; return true; } std::vector HttpHeaderMap::get_strict(std::string key) { 
std::transform(key.begin(), key.end(), key.begin(), ::tolower); return header_map_[key]; } bool HttpHeaderMap::get_strict(std::string key, std::vector& values) { std::transform(key.begin(), key.end(), key.begin(), ::tolower); const auto it = header_map_.find(key); if (it == header_map_.end() || it->second.empty()) return false; values = it->second; return true; } std::string HttpUtil::decode_chunked_body(const HttpMessage *msg) { const void *body; size_t body_len; const void *chunk; size_t chunk_size; std::string decode_result; HttpChunkCursor cursor(msg); if (msg->get_parsed_body(&body, &body_len)) { decode_result.reserve(body_len); while (cursor.next(&chunk, &chunk_size)) decode_result.append((const char *)chunk, chunk_size); } return decode_result; } void HttpUtil::set_response_status(HttpResponse *resp, int status_code) { char buf[32]; sprintf(buf, "%d", status_code); resp->set_status_code(buf); switch (status_code) { case 100: resp->set_reason_phrase("Continue"); break; case 101: resp->set_reason_phrase("Switching Protocols"); break; case 102: resp->set_reason_phrase("Processing"); break; case 200: resp->set_reason_phrase("OK"); break; case 201: resp->set_reason_phrase("Created"); break; case 202: resp->set_reason_phrase("Accepted"); break; case 203: resp->set_reason_phrase("Non-Authoritative Information"); break; case 204: resp->set_reason_phrase("No Content"); break; case 205: resp->set_reason_phrase("Reset Content"); break; case 206: resp->set_reason_phrase("Partial Content"); break; case 207: resp->set_reason_phrase("Multi-Status"); break; case 208: resp->set_reason_phrase("Already Reported"); break; case 226: resp->set_reason_phrase("IM Used"); break; case 300: resp->set_reason_phrase("Multiple Choices"); break; case 301: resp->set_reason_phrase("Moved Permanently"); break; case 302: resp->set_reason_phrase("Found"); break; case 303: resp->set_reason_phrase("See Other"); break; case 304: resp->set_reason_phrase("Not Modified"); break; case 305: 
resp->set_reason_phrase("Use Proxy"); break; case 306: resp->set_reason_phrase("Switch Proxy"); break; case 307: resp->set_reason_phrase("Temporary Redirect"); break; case 308: resp->set_reason_phrase("Permanent Redirect"); break; case 400: resp->set_reason_phrase("Bad Request"); break; case 401: resp->set_reason_phrase("Unauthorized"); break; case 402: resp->set_reason_phrase("Payment Required"); break; case 403: resp->set_reason_phrase("Forbidden"); break; case 404: resp->set_reason_phrase("Not Found"); break; case 405: resp->set_reason_phrase("Method Not Allowed"); break; case 406: resp->set_reason_phrase("Not Acceptable"); break; case 407: resp->set_reason_phrase("Proxy Authentication Required"); break; case 408: resp->set_reason_phrase("Request Timeout"); break; case 409: resp->set_reason_phrase("Conflict"); break; case 410: resp->set_reason_phrase("Gone"); break; case 411: resp->set_reason_phrase("Length Required"); break; case 412: resp->set_reason_phrase("Precondition Failed"); break; case 413: resp->set_reason_phrase("Request Entity Too Large"); break; case 414: resp->set_reason_phrase("Request-URI Too Long"); break; case 415: resp->set_reason_phrase("Unsupported Media Type"); break; case 416: resp->set_reason_phrase("Requested Range Not Satisfiable"); break; case 417: resp->set_reason_phrase("Expectation Failed"); break; case 418: resp->set_reason_phrase("I'm a teapot"); break; case 420: resp->set_reason_phrase("Enhance Your Caim"); break; case 421: resp->set_reason_phrase("Misdirected Request"); break; case 422: resp->set_reason_phrase("Unprocessable Entity"); break; case 423: resp->set_reason_phrase("Locked"); break; case 424: resp->set_reason_phrase("Failed Dependency"); break; case 425: resp->set_reason_phrase("Too Early"); break; case 426: resp->set_reason_phrase("Upgrade Required"); break; case 428: resp->set_reason_phrase("Precondition Required"); break; case 429: resp->set_reason_phrase("Too Many Requests"); break; case 431: 
resp->set_reason_phrase("Request Header Fields Too Large"); break; case 444: resp->set_reason_phrase("No Response"); break; case 450: resp->set_reason_phrase("Blocked by Windows Parental Controls"); break; case 451: resp->set_reason_phrase("Unavailable For Legal Reasons"); break; case 494: resp->set_reason_phrase("Request Header Too Large"); break; case 500: resp->set_reason_phrase("Internal Server Error"); break; case 501: resp->set_reason_phrase("Not Implemented"); break; case 502: resp->set_reason_phrase("Bad Gateway"); break; case 503: resp->set_reason_phrase("Service Unavailable"); break; case 504: resp->set_reason_phrase("Gateway Timeout"); break; case 505: resp->set_reason_phrase("HTTP Version Not Supported"); break; case 506: resp->set_reason_phrase("Variant Also Negotiates"); break; case 507: resp->set_reason_phrase("Insufficient Storage"); break; case 508: resp->set_reason_phrase("Loop Detected"); break; case 510: resp->set_reason_phrase("Not Extended"); break; case 511: resp->set_reason_phrase("Network Authentication Required"); break; default: resp->set_reason_phrase("Unknown"); break; } } bool HttpHeaderCursor::next(std::string& name, std::string& value) { struct HttpMessageHeader header; if (this->next(&header)) { name.assign((const char *)header.name, header.name_len); value.assign((const char *)header.value, header.value_len); return true; } return false; } bool HttpHeaderCursor::find(const std::string& name, std::string& value) { struct HttpMessageHeader header = { .name = name.c_str(), .name_len = name.size(), }; if (this->find(&header)) { value.assign((const char *)header.value, header.value_len); return true; } return false; } bool HttpHeaderCursor::find_and_erase(const std::string& name) { struct HttpMessageHeader header = { .name = name.c_str(), .name_len = name.size(), }; return this->find_and_erase(&header); } HttpChunkCursor::HttpChunkCursor(const HttpMessage *msg) { if (msg->get_parsed_body(&this->body, &this->body_len)) { this->pos = 
this->body; this->chunked = msg->is_chunked(); this->end = false; } else { this->body = NULL; this->end = true; } } bool HttpChunkCursor::next(const void **chunk, size_t *size) { if (this->end) return false; if (!this->chunked) { *chunk = this->body; *size = this->body_len; this->end = true; return true; } const char *cur = (const char *)this->pos; char *end; *size = strtol(cur, &end, 16); if (*size == 0) { this->end = true; return false; } cur = strchr(end, '\r'); *chunk = cur + 2; cur += *size + 4; this->pos = cur; return true; } void HttpChunkCursor::rewind() { if (this->body) { this->pos = this->body; this->end = false; } } } workflow-0.11.8/src/protocol/HttpUtil.h000066400000000000000000000167621476003635400200550ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Authors: Xie Han (xiehan@sogou-inc.com) Wu Jiaxu (wujiaxu@sogou-inc.com) */ #ifndef _HTTPUTIL_H_ #define _HTTPUTIL_H_ #include #include #include #include "http_parser.h" #include "HttpMessage.h" /** * @file HttpUtil.h * @brief Http toolbox */ #define HttpMethodGet "GET" #define HttpMethodHead "HEAD" #define HttpMethodPost "POST" #define HttpMethodPut "PUT" #define HttpMethodPatch "PATCH" #define HttpMethodDelete "DELETE" #define HttpMethodConnect "CONNECT" #define HttpMethodOptions "OPTIONS" #define HttpMethodTrace "TRACE" enum { HttpStatusContinue = 100, // RFC 7231, 6.2.1 HttpStatusSwitchingProtocols = 101, // RFC 7231, 6.2.2 HttpStatusProcessing = 102, // RFC 2518, 10.1 HttpStatusOK = 200, // RFC 7231, 6.3.1 HttpStatusCreated = 201, // RFC 7231, 6.3.2 HttpStatusAccepted = 202, // RFC 7231, 6.3.3 HttpStatusNonAuthoritativeInfo = 203, // RFC 7231, 6.3.4 HttpStatusNoContent = 204, // RFC 7231, 6.3.5 HttpStatusResetContent = 205, // RFC 7231, 6.3.6 HttpStatusPartialContent = 206, // RFC 7233, 4.1 HttpStatusMultiStatus = 207, // RFC 4918, 11.1 HttpStatusAlreadyReported = 208, // RFC 5842, 7.1 HttpStatusIMUsed = 226, // RFC 3229, 10.4.1 HttpStatusMultipleChoices = 300, // RFC 7231, 6.4.1 HttpStatusMovedPermanently = 301, // RFC 7231, 6.4.2 HttpStatusFound = 302, // RFC 7231, 6.4.3 HttpStatusSeeOther = 303, // RFC 7231, 6.4.4 HttpStatusNotModified = 304, // RFC 7232, 4.1 HttpStatusUseProxy = 305, // RFC 7231, 6.4.5 HttpStatusTemporaryRedirect = 307, // RFC 7231, 6.4.7 HttpStatusPermanentRedirect = 308, // RFC 7538, 3 HttpStatusBadRequest = 400, // RFC 7231, 6.5.1 HttpStatusUnauthorized = 401, // RFC 7235, 3.1 HttpStatusPaymentRequired = 402, // RFC 7231, 6.5.2 HttpStatusForbidden = 403, // RFC 7231, 6.5.3 HttpStatusNotFound = 404, // RFC 7231, 6.5.4 HttpStatusMethodNotAllowed = 405, // RFC 7231, 6.5.5 HttpStatusNotAcceptable = 406, // RFC 7231, 6.5.6 HttpStatusProxyAuthRequired = 407, // RFC 7235, 3.2 HttpStatusRequestTimeout = 408, // RFC 7231, 6.5.7 
HttpStatusConflict = 409, // RFC 7231, 6.5.8 HttpStatusGone = 410, // RFC 7231, 6.5.9 HttpStatusLengthRequired = 411, // RFC 7231, 6.5.10 HttpStatusPreconditionFailed = 412, // RFC 7232, 4.2 HttpStatusRequestEntityTooLarge = 413, // RFC 7231, 6.5.11 HttpStatusRequestURITooLong = 414, // RFC 7231, 6.5.12 HttpStatusUnsupportedMediaType = 415, // RFC 7231, 6.5.13 HttpStatusRequestedRangeNotSatisfiable = 416, // RFC 7233, 4.4 HttpStatusExpectationFailed = 417, // RFC 7231, 6.5.14 HttpStatusTeapot = 418, // RFC 7168, 2.3.3 HttpStatusEnhanceYourCaim = 420, // Twitter Search HttpStatusMisdirectedRequest = 421, // RFC 7540, 9.1.2 HttpStatusUnprocessableEntity = 422, // RFC 4918, 11.2 HttpStatusLocked = 423, // RFC 4918, 11.3 HttpStatusFailedDependency = 424, // RFC 4918, 11.4 HttpStatusTooEarly = 425, // RFC 8470, 5.2. HttpStatusUpgradeRequired = 426, // RFC 7231, 6.5.15 HttpStatusPreconditionRequired = 428, // RFC 6585, 3 HttpStatusTooManyRequests = 429, // RFC 6585, 4 HttpStatusRequestHeaderFieldsTooLarge = 431, // RFC 6585, 5 HttpStatusNoResponse = 444, // Nginx HttpStatusBlocked = 450, // Windows HttpStatusUnavailableForLegalReasons = 451, // RFC 7725, 3 HttpStatusTooLargeForNginx = 494, // Nginx HttpStatusInternalServerError = 500, // RFC 7231, 6.6.1 HttpStatusNotImplemented = 501, // RFC 7231, 6.6.2 HttpStatusBadGateway = 502, // RFC 7231, 6.6.3 HttpStatusServiceUnavailable = 503, // RFC 7231, 6.6.4 HttpStatusGatewayTimeout = 504, // RFC 7231, 6.6.5 HttpStatusHTTPVersionNotSupported = 505, // RFC 7231, 6.6.6 HttpStatusVariantAlsoNegotiates = 506, // RFC 2295, 8.1 HttpStatusInsufficientStorage = 507, // RFC 4918, 11.5 HttpStatusLoopDetected = 508, // RFC 5842, 7.2 HttpStatusNotExtended = 510, // RFC 2774, 7 HttpStatusNetworkAuthenticationRequired = 511, // RFC 6585, 6 }; namespace protocol { // static class class HttpUtil { public: static void set_response_status(HttpResponse *resp, int status_code); static std::string decode_chunked_body(const HttpMessage *msg); }; 
class HttpHeaderMap { public: HttpHeaderMap(const HttpMessage *message); bool key_exists(std::string key); std::string get(std::string key); bool get(std::string key, std::string& value); std::vector get_strict(std::string key); bool get_strict(std::string key, std::vector& values); private: std::unordered_map> header_map_; }; class HttpHeaderCursor { public: HttpHeaderCursor(const HttpMessage *message); virtual ~HttpHeaderCursor(); public: bool next(struct HttpMessageHeader *header); bool find(struct HttpMessageHeader *header); bool erase(); bool find_and_erase(struct HttpMessageHeader *header); void rewind(); /* std::string interface */ public: bool next(std::string& name, std::string& value); bool find(const std::string& name, std::string& value); bool find_and_erase(const std::string& name); protected: http_header_cursor_t cursor; }; class HttpChunkCursor { public: HttpChunkCursor(const HttpMessage *message); virtual ~HttpChunkCursor() { } public: bool next(const void **chunk, size_t *size); void rewind(); protected: const void *body; size_t body_len; const void *pos; bool chunked; bool end; }; //////////////////// inline HttpHeaderCursor::HttpHeaderCursor(const HttpMessage *message) { http_header_cursor_init(&this->cursor, message->get_parser()); } inline HttpHeaderCursor::~HttpHeaderCursor() { http_header_cursor_deinit(&this->cursor); } inline bool HttpHeaderCursor::next(struct HttpMessageHeader *header) { return http_header_cursor_next(&header->name, &header->name_len, &header->value, &header->value_len, &this->cursor) == 0; } inline bool HttpHeaderCursor::find(struct HttpMessageHeader *header) { return http_header_cursor_find(header->name, header->name_len, &header->value, &header->value_len, &this->cursor) == 0; } inline bool HttpHeaderCursor::erase() { return http_header_cursor_erase(&this->cursor) == 0; } inline bool HttpHeaderCursor::find_and_erase(struct HttpMessageHeader *header) { if (this->find(header)) return this->erase(); return false; } inline 
void HttpHeaderCursor::rewind() { http_header_cursor_rewind(&this->cursor); } } #endif workflow-0.11.8/src/protocol/KafkaDataTypes.cc000066400000000000000000000362631476003635400212700ustar00rootroot00000000000000/* Copyright (c) 2020 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Authors: Wang Zhulei (wangzhulei@sogou-inc.com) */ #include #include #include #include "KafkaDataTypes.h" #define MIN(x, y) ((x) <= (y) ? (x) : (y)) namespace protocol { std::string KafkaConfig::get_sasl_info() const { std::string info; if (strcasecmp(this->ptr->mechanisms, "plain") == 0) { info += this->ptr->mechanisms; info += "|"; info += this->ptr->username; info += "|"; info += this->ptr->password; info += "|"; } else if (strncasecmp(this->ptr->mechanisms, "SCRAM", 5) == 0) { info += this->ptr->mechanisms; info += "|"; info += this->ptr->username; info += "|"; info += this->ptr->password; info += "|"; } return info; } static bool compare_member(const kafka_member_t *m1, const kafka_member_t *m2) { return strcmp(m1->member_id, m2->member_id) < 0; } inline void KafkaMetaSubscriber::sort_by_member() { std::sort(this->member_vec.begin(), this->member_vec.end(), compare_member); } static bool operator<(const KafkaMetaSubscriber& s1, const KafkaMetaSubscriber& s2) { return strcmp(s1.get_meta()->get_topic(), s2.get_meta()->get_topic()) < 0; } /* * For example, suppose there are two consumers C0 and C1, two topics t0 and t1, and each topic has 3 partitions, * resulting in partitions t0p0, t0p1, t0p2, t1p0, t1p1, and 
t1p2.
 *
 * The assignment will be:
 * C0: [t0p0, t0p1, t1p0, t1p1]
 * C1: [t0p2, t1p2]
 */
int KafkaCgroup::kafka_range_assignor(kafka_member_t **members,
									  int member_elements,
									  void *meta_topic)
{
	/* 'meta_topic' carries the per-topic subscriber list; 'members' and
	 * 'member_elements' are not needed by this strategy. */
	std::vector<KafkaMetaSubscriber> *subscribers =
		static_cast<std::vector<KafkaMetaSubscriber> *>(meta_topic);

	/* The range assignor works on a per-topic basis. */
	for (auto& subscriber : *subscribers)
	{
		subscriber.sort_by_member();

		const char *topic = subscriber.get_meta()->get_topic();
		int partition_cnt = subscriber.get_meta()->get_partition_elements();
		/* Work in int so the division below is not silently promoted to
		 * unsigned arithmetic. */
		int consumer_cnt = (int)subscriber.get_member()->size();

		if (consumer_cnt == 0)
			continue;

		int per_consumer = partition_cnt / consumer_cnt;

		/* If it does not evenly divide, then the first few consumers
		 * will have one extra partition. */
		int with_extra = partition_cnt % consumer_cnt;

		for (int i = 0; i < consumer_cnt; i++)
		{
			int start = per_consumer * i + MIN(i, with_extra);
			int length = per_consumer + (i < with_extra ? 1 : 0);

			for (int j = start; j < start + length; ++j)
			{
				KafkaToppar *toppar = new KafkaToppar;

				if (!toppar->set_topic_partition(topic, j))
				{
					delete toppar;
					return -1;
				}

				list_add_tail(&toppar->list,
					&subscriber.get_member()->at(i)->assigned_toppar_list);
			}
		}
	}

	return 0;
}

/*
 * For example, suppose there are two consumers C0 and C1, two topics t0 and
 * t1, and each topic has 3 partitions, resulting in partitions t0p0, t0p1,
 * t0p2, t1p0, t1p1, and t1p2.
* * The assignment will be: * C0: [t0p0, t0p2, t1p1] * C1: [t0p1, t1p0, t1p2] */ int KafkaCgroup::kafka_roundrobin_assignor(kafka_member_t **members, int member_elements, void *meta_topic) { std::vector *subscribers = static_cast *>(meta_topic); int next = -1; std::sort(subscribers->begin(), subscribers->end()); std::sort(members, members + member_elements, compare_member); for (const auto& subscriber : *subscribers) { int partition_elements = subscriber.get_meta()->get_partition_elements(); for (int partition = 0; partition < partition_elements; ++partition) { next = (next + 1) % subscriber.get_member()->size(); struct list_head *pos; KafkaToppar *toppar; int i = 0; for (; i < member_elements; i++) { bool flag = false; list_for_each(pos, &members[next + i]->toppar_list) { toppar = list_entry(pos, KafkaToppar, list); if (strcmp(subscriber.get_meta()->get_topic(), toppar->get_topic()) == 0) { flag = true; break; } } if (flag) break; } if (i >= member_elements) return -1; toppar = new KafkaToppar; if (!toppar->set_topic_partition(subscriber.get_meta()->get_topic(), partition)) { delete toppar; return -1; } list_add_tail(toppar->get_list(), &members[next]->assigned_toppar_list); } } return 0; } bool KafkaMeta::create_partitions(int partition_cnt) { if (partition_cnt <= 0) return true; kafka_partition_t **partitions; partitions = (kafka_partition_t **)malloc(sizeof(void *) * partition_cnt); if (!partitions) return false; int i; for (i = 0; i < partition_cnt; ++i) { partitions[i] = (kafka_partition_t *)malloc(sizeof(kafka_partition_t)); if (!partitions[i]) break; kafka_partition_init(partitions[i]); } if (i != partition_cnt) { while (--i >= 0) { kafka_partition_deinit(partitions[i]); free(partitions[i]); } free(partitions); return false; } for (i = 0; i < this->ptr->partition_elements; ++i) { kafka_partition_deinit(this->ptr->partitions[i]); free(this->ptr->partitions[i]); } free(this->ptr->partitions); this->ptr->partitions = partitions; this->ptr->partition_elements = 
partition_cnt; return true; } void KafkaCgroup::add_subscriber(KafkaMetaList *meta_list, std::vector *subscribers) { meta_list->rewind(); KafkaMeta *meta; while ((meta = meta_list->get_next()) != NULL) { KafkaMetaSubscriber subscriber; subscriber.set_meta(meta); for (int i = 0; i < this->get_member_elements(); ++i) { struct list_head *pos; KafkaToppar *toppar; bool flag = false; list_for_each(pos, &this->get_members()[i]->toppar_list) { toppar = list_entry(pos, KafkaToppar, list); if (strcmp(meta->get_topic(), toppar->get_topic()) == 0) { flag = true; break; } } if (flag) subscriber.add_member(this->get_members()[i]); } if (!subscriber.get_member()->empty()) subscribers->emplace_back(subscriber); } } int KafkaCgroup::run_assignor(KafkaMetaList *meta_list, const char *protocol_name) { std::vector subscribers; this->add_subscriber(meta_list, &subscribers); struct list_head *pos; kafka_group_protocol_t *protocol; bool flag = false; list_for_each(pos, this->get_group_protocol()) { protocol = list_entry(pos, kafka_group_protocol_t, list); if (strcmp(protocol_name, protocol->protocol_name) == 0) { flag = true; break; } } if (!flag) { errno = EBADMSG; return -1; } return protocol->assignor(this->get_members(), this->get_member_elements(), &subscribers); } KafkaCgroup::KafkaCgroup() { this->ptr = new kafka_cgroup_t; kafka_cgroup_init(this->ptr); kafka_group_protocol_t *protocol = new kafka_group_protocol_t; protocol->protocol_name = new char[strlen("range") + 1]; memcpy(protocol->protocol_name, "range", strlen("range") + 1); protocol->assignor = kafka_range_assignor; list_add_tail(&protocol->list, &this->ptr->group_protocol_list); protocol = new kafka_group_protocol_t; protocol->protocol_name = new char[strlen("roundrobin") + 1]; memcpy(protocol->protocol_name, "roundrobin", strlen("roundrobin") + 1); protocol->assignor = kafka_roundrobin_assignor; list_add_tail(&protocol->list, &this->ptr->group_protocol_list); this->ref = new std::atomic(1); this->coordinator = NULL; } 
KafkaCgroup::~KafkaCgroup() { if (--*this->ref == 0) { for (int i = 0; i < this->ptr->member_elements; ++i) { kafka_member_t *member = this->ptr->members[i]; KafkaToppar *toppar; struct list_head *pos, *tmp; list_for_each_safe(pos, tmp, &member->toppar_list) { toppar = list_entry(pos, KafkaToppar, list); list_del(pos); delete toppar; } list_for_each_safe(pos, tmp, &member->assigned_toppar_list) { toppar = list_entry(pos, KafkaToppar, list); list_del(pos); delete toppar; } } kafka_cgroup_deinit(this->ptr); struct list_head *tmp, *pos; KafkaToppar *toppar; list_for_each_safe(pos, tmp, &this->ptr->assigned_toppar_list) { toppar = list_entry(pos, KafkaToppar, list); list_del(pos); delete toppar; } kafka_group_protocol_t *protocol; list_for_each_safe(pos, tmp, &this->ptr->group_protocol_list) { protocol = list_entry(pos, kafka_group_protocol_t, list); list_del(pos); delete []protocol->protocol_name; delete protocol; } delete []this->ptr->group_name; delete this->ptr; delete this->ref; } delete this->coordinator; } KafkaCgroup::KafkaCgroup(KafkaCgroup&& move) { this->ptr = move.ptr; this->ref = move.ref; move.ptr = new kafka_cgroup_t; kafka_cgroup_init(move.ptr); move.ref = new std::atomic(1); this->coordinator = move.coordinator; move.coordinator = NULL; } KafkaCgroup& KafkaCgroup::operator= (KafkaCgroup&& move) { if (this != &move) { this->~KafkaCgroup(); this->ptr = move.ptr; this->ref = move.ref; move.ptr = new kafka_cgroup_t; kafka_cgroup_init(move.ptr); move.ref = new std::atomic(1); this->coordinator = move.coordinator; move.coordinator = NULL; } return *this; } KafkaCgroup::KafkaCgroup(const KafkaCgroup& copy) { this->ptr = copy.ptr; this->ref = copy.ref; ++*this->ref; if (copy.coordinator) this->coordinator = new KafkaBroker(copy.coordinator->get_raw_ptr()); else this->coordinator = NULL; } KafkaCgroup& KafkaCgroup::operator= (const KafkaCgroup& copy) { this->~KafkaCgroup(); this->ptr = copy.ptr; this->ref = copy.ref; ++*this->ref; if (copy.coordinator) 
this->coordinator = new KafkaBroker(copy.coordinator->get_raw_ptr()); else this->coordinator = NULL; return *this; } bool KafkaCgroup::create_members(int member_cnt) { if (member_cnt == 0) return true; kafka_member_t **members; members = (kafka_member_t **)malloc(sizeof(void *) * member_cnt); if (!members) return false; int i; for (i = 0; i < member_cnt; ++i) { members[i] = (kafka_member_t *)malloc(sizeof(kafka_member_t)); if (!members[i]) break; kafka_member_init(members[i]); INIT_LIST_HEAD(&members[i]->toppar_list); INIT_LIST_HEAD(&members[i]->assigned_toppar_list); } if (i != member_cnt) { while (--i >= 0) { KafkaToppar *toppar; struct list_head *pos, *tmp; list_for_each_safe(pos, tmp, &members[i]->toppar_list) { toppar = list_entry(pos, KafkaToppar, list); list_del(pos); delete toppar; } list_for_each_safe(pos, tmp, &members[i]->assigned_toppar_list) { toppar = list_entry(pos, KafkaToppar, list); list_del(pos); delete toppar; } kafka_member_deinit(members[i]); free(members[i]); } free(members); return false; } for (i = 0; i < this->ptr->member_elements; ++i) { KafkaToppar *toppar; struct list_head *pos, *tmp; list_for_each_safe(pos, tmp, &this->ptr->members[i]->toppar_list) { toppar = list_entry(pos, KafkaToppar, list); list_del(pos); delete toppar; } list_for_each_safe(pos, tmp, &this->ptr->members[i]->assigned_toppar_list) { toppar = list_entry(pos, KafkaToppar, list); list_del(pos); delete toppar; } kafka_member_deinit(this->ptr->members[i]); free(this->ptr->members[i]); } free(this->ptr->members); this->ptr->members = members; this->ptr->member_elements = member_cnt; return true; } void KafkaCgroup::add_assigned_toppar(KafkaToppar *toppar) { list_add_tail(toppar->get_list(), &this->ptr->assigned_toppar_list); } void KafkaCgroup::assigned_toppar_rewind() { this->curpos = &this->ptr->assigned_toppar_list; } KafkaToppar *KafkaCgroup::get_assigned_toppar_next() { if (this->curpos->next == &this->ptr->assigned_toppar_list) return NULL; this->curpos = 
this->curpos->next; return list_entry(this->curpos, KafkaToppar, list); } void KafkaCgroup::del_assigned_toppar_cur() { assert(this->curpos != &this->ptr->assigned_toppar_list); this->curpos = this->curpos->prev; list_del(this->curpos->next); } bool KafkaRecord::add_header_pair(const void *key, size_t key_len, const void *val, size_t val_len) { kafka_record_header_t *header; header = (kafka_record_header_t *)malloc(sizeof(kafka_record_header_t)); if (!header) return false; kafka_record_header_init(header); if (kafka_record_header_set_kv(key, key_len, val, val_len, header) < 0) { free(header); return false; } list_add_tail(&header->list, &this->ptr->header_list); return true; } bool KafkaRecord::add_header_pair(const std::string& key, const std::string& val) { return add_header_pair(key.c_str(), key.size(), val.c_str(), val.size()); } KafkaToppar::~KafkaToppar() { if (--*this->ref == 0) { kafka_topic_partition_deinit(this->ptr); struct list_head *tmp, *pos; KafkaRecord *record; list_for_each_safe(pos, tmp, &this->ptr->record_list) { record = list_entry(pos, KafkaRecord, list); list_del(pos); delete record; } delete this->ptr; delete this->ref; } } void KafkaBuffer::list_splice(KafkaBuffer *buffer) { struct list_head *pre_insert; struct list_head *pre_tail; this->buf_size -= this->insert_buf_size; pre_insert = this->insert_pos->next; __list_splice(buffer->get_head(), this->insert_pos, pre_insert); pre_tail = this->block_list.get_tail(); buffer->get_head()->prev->next = this->block_list.get_head(); this->block_list.get_head()->prev = buffer->get_head()->prev; buffer->get_head()->next = pre_insert; buffer->get_head()->prev = pre_tail; pre_tail->next = buffer->get_head(); pre_insert->prev = buffer->get_head(); this->buf_size += buffer->get_size(); } size_t KafkaBuffer::peek(const char **buf) { if (!this->inited) { this->inited = true; this->cur_pos = std::make_pair(this->block_list.get_next(), 0); } if (this->cur_pos.first == this->block_list.get_tail_entry() && 
this->cur_pos.second == this->block_list.get_tail_entry()->get_len()) { *buf = NULL; return 0; } KafkaBlock *block = this->cur_pos.first; if (this->cur_pos.second >= block->get_len()) { block = this->block_list.get_next(); this->cur_pos = std::make_pair(block, 0); } *buf = (char *)block->get_block() + this->cur_pos.second; return block->get_len() - this->cur_pos.second; } KafkaToppar *get_toppar(const char *topic, int partition, KafkaTopparList *toppar_list) { struct list_head *pos; KafkaToppar *toppar; list_for_each(pos, toppar_list->get_head()) { toppar = list_entry(pos, KafkaToppar, list); if (strcmp(toppar->get_topic(), topic) == 0 && toppar->get_partition() == partition) return toppar; } return NULL; } const KafkaMeta *get_meta(const char *topic, KafkaMetaList *meta_list) { struct list_head *pos; const KafkaMeta *meta; list_for_each(pos, meta_list->get_head()) { meta = list_entry(pos, KafkaMeta, list); if (strcmp(meta->get_topic(), topic) == 0) return meta; } return NULL; } } /* namespace protocol */ workflow-0.11.8/src/protocol/KafkaDataTypes.h000066400000000000000000000773331476003635400211350ustar00rootroot00000000000000/* Copyright (c) 2020 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Authors: Wang Zhulei (wangzhulei@sogou-inc.com) */ #ifndef _KAFKA_DATATYPES_H_ #define _KAFKA_DATATYPES_H_ #include #include #include #include #include #include #include #include #include #include #include #include "list.h" #include "rbtree.h" #include "kafka_parser.h" namespace protocol { template class KafkaList { public: KafkaList() { this->t_list = new struct list_head; INIT_LIST_HEAD(this->t_list); this->ref = new std::atomic(1); this->curpos = this->t_list; } ~KafkaList() { if (--*this->ref == 0) { struct list_head *pos, *tmp; T *t; list_for_each_safe(pos, tmp, this->t_list) { t = list_entry(pos, T, list); list_del(pos); delete t; } delete this->t_list; delete this->ref; } } KafkaList(KafkaList&& move) { this->t_list = move.t_list; move.t_list = new struct list_head; INIT_LIST_HEAD(move.t_list); this->ref = move.ref; this->curpos = this->t_list; move.ref = new std::atomic(1); } KafkaList& operator= (KafkaList&& move) { if (this != &move) { this->~KafkaList(); this->t_list = move.t_list; move.t_list = new struct list_head; INIT_LIST_HEAD(move.t_list); this->ref = move.ref; this->curpos = this->t_list; move.ref = new std::atomic(1); } return *this; } KafkaList(const KafkaList& copy) { this->ref = copy.ref; ++*this->ref; this->t_list = copy.t_list; this->curpos = copy.curpos; } KafkaList& operator= (const KafkaList& copy) { if (this != ©) { this->~KafkaList(); this->ref = copy.ref; ++*this->ref; this->t_list = copy.t_list; this->curpos = copy.curpos; } return *this; } T *add_item(T&& move) { T *t = new T; *t = std::move(move); list_add_tail(t->get_list(), this->t_list); return t; } void add_item(T& obj) { T *t = new T; *t = obj; list_add_tail(t->get_list(), this->t_list); } struct list_head *get_head() { return this->t_list; } struct list_head *get_tail() { return this->t_list->prev; } T *get_first_entry() { if (this->t_list == this->t_list->next) return NULL; return list_entry(this->t_list->next, T, list); } T *get_tail_entry() { if (this->t_list == 
this->get_tail()) return NULL; return list_entry(this->get_tail(), T, list); } T *get_entry(struct list_head *pos) { return list_entry(pos, T, list); } void rewind() { this->curpos = this->t_list; } T *get_next() { if (this->curpos->next == this->t_list) return NULL; this->curpos = this->curpos->next; return list_entry(this->curpos, T, list); } void insert_pos(struct list_head *list, struct list_head *pos) { __list_add(list, pos, pos->next); } void del_cur() { assert(this->curpos != this->t_list); this->curpos = this->curpos->prev; list_del(this->curpos->next); } private: struct list_head *t_list; std::atomic *ref; struct list_head *curpos; }; template class KafkaMap { public: KafkaMap() { this->t_map = new struct rb_root; this->t_map->rb_node = NULL; this->ref = new std::atomic(1); } ~KafkaMap() { if (--*this->ref == 0) { T *t; while (this->t_map->rb_node) { t = rb_entry(this->t_map->rb_node, T, rb); rb_erase(this->t_map->rb_node, this->t_map); delete t; } delete this->t_map; delete this->ref; } } KafkaMap(const KafkaMap& copy) { this->ref = copy.ref; ++*this->ref; this->t_map = copy.t_map; } KafkaMap& operator= (const KafkaMap& copy) { if (this != ©) { this->~KafkaMap(); this->ref = copy.ref; ++*this->ref; this->t_map = copy.t_map; } return *this; } T *find_item(const T& v) const { rb_node **p = &this->t_map->rb_node; T *t; while (*p) { t = rb_entry(*p, T, rb); if (v < *t) p = &(*p)->rb_left; else if (v > *t) p = &(*p)->rb_right; else break; } return *p ? 
t : NULL; } void add_item(T& obj) { rb_node **p = &this->t_map->rb_node; rb_node *parent = NULL; T *t; while (*p) { parent = *p; t = rb_entry(*p, T, rb); if (obj < *t) p = &(*p)->rb_left; else if (obj > *t) p = &(*p)->rb_right; else break; } if (*p == NULL) { T *nt = new T; *nt = obj; rb_link_node(nt->get_rb(), parent, p); rb_insert_color(nt->get_rb(), this->t_map); } } T *find_item(int id) const { rb_node **p = &this->t_map->rb_node; T *t; while (*p) { t = rb_entry(*p, T, rb); if (id < t->get_id()) p = &(*p)->rb_left; else if (id > t->get_id()) p = &(*p)->rb_right; else break; } return *p ? t : NULL; } void add_item(T& obj, int id) { rb_node **p = &this->t_map->rb_node; rb_node *parent = NULL; T *t; while (*p) { parent = *p; t = rb_entry(*p, T, rb); if (id < t->get_id()) p = &(*p)->rb_left; else if (id > t->get_id()) p = &(*p)->rb_right; else break; } if (*p == NULL) { T *nt = new T; *nt = obj; rb_link_node(nt->get_rb(), parent, p); rb_insert_color(nt->get_rb(), this->t_map); } } T *get_first_entry() { struct rb_node *p = rb_first(this->t_map); return rb_entry(p, T, rb); } T *get_tail_entry() { struct rb_node *p = rb_last(this->t_map); return rb_entry(p, T, rb); } private: struct rb_root *t_map; std::atomic *ref; }; class KafkaConfig { public: void set_produce_timeout(int ms) { this->ptr->produce_timeout = ms; } int get_produce_timeout() const { return this->ptr->produce_timeout; } void set_produce_msg_max_bytes(int bytes) { this->ptr->produce_msg_max_bytes = bytes; } int get_produce_msg_max_bytes() const { return this->ptr->produce_msg_max_bytes; } void set_produce_msgset_cnt(int cnt) { this->ptr->produce_msgset_cnt = cnt; } int get_produce_msgset_cnt() const { return this->ptr->produce_msgset_cnt; } void set_produce_msgset_max_bytes(int bytes) { this->ptr->produce_msgset_max_bytes = bytes; } int get_produce_msgset_max_bytes() const { return this->ptr->produce_msgset_max_bytes; } void set_fetch_timeout(int ms) { this->ptr->fetch_timeout = ms; } int 
get_fetch_timeout() const { return this->ptr->fetch_timeout; } /* The accessors on this line store to or read from the same-named field of the kafka_config_t behind ptr. */ void set_fetch_min_bytes(int bytes) { this->ptr->fetch_min_bytes = bytes; } int get_fetch_min_bytes() const { return this->ptr->fetch_min_bytes; } void set_fetch_max_bytes(int bytes) { this->ptr->fetch_max_bytes = bytes; } int get_fetch_max_bytes() const { return this->ptr->fetch_max_bytes; } void set_fetch_msg_max_bytes(int bytes) { this->ptr->fetch_msg_max_bytes = bytes; } int get_fetch_msg_max_bytes() const { return this->ptr->fetch_msg_max_bytes; } void set_offset_timestamp(long long tm) { this->ptr->offset_timestamp = tm; } long long get_offset_timestamp() const { return this->ptr->offset_timestamp; } void set_commit_timestamp(long long commit_timestamp) { this->ptr->commit_timestamp = commit_timestamp; } long long get_commit_timestamp() const { return this->ptr->commit_timestamp; } void set_session_timeout(int ms) { this->ptr->session_timeout = ms; } int get_session_timeout() const { return this->ptr->session_timeout; } void set_rebalance_timeout(int ms) { this->ptr->rebalance_timeout = ms; } int get_rebalance_timeout() const { return this->ptr->rebalance_timeout; } void set_retention_time_period(long long ms) { this->ptr->retention_time_period = ms; } long long get_retention_time_period() const { return this->ptr->retention_time_period; } void set_produce_acks(int acks) { this->ptr->produce_acks = acks; } int get_produce_acks() const { return this->ptr->produce_acks; } void set_allow_auto_topic_creation(bool allow_auto_topic_creation) { this->ptr->allow_auto_topic_creation = allow_auto_topic_creation; } bool get_allow_auto_topic_creation() const { return this->ptr->allow_auto_topic_creation; } void set_api_version_request(int api_ver) { this->ptr->api_version_request = api_ver; } int get_api_version_request() const { return this->ptr->api_version_request; } /* Duplicates version with strdup(); if that fails, returns false before the stored string is touched. */ bool set_broker_version(const char *version) { char *p = strdup(version); if (!p) return false; free(this->ptr->broker_version);
this->ptr->broker_version = p; return true; } const char *get_broker_version() const { return this->ptr->broker_version; } void set_compress_type(int type) { this->ptr->compress_type = type; } int get_compress_type() const { return this->ptr->compress_type; } const char *get_client_id() { return this->ptr->client_id; } /* String setters below share one shape: strdup() the argument, bail out with false on failure, free() the old value, store the copy. */ bool set_client_id(const char *client_id) { char *p = strdup(client_id); if (!p) return false; free(this->ptr->client_id); this->ptr->client_id = p; return true; } bool get_check_crcs() const { return this->ptr->check_crcs != 0; } void set_check_crcs(bool check_crcs) { this->ptr->check_crcs = check_crcs; } int get_offset_store() const { return this->ptr->offset_store; } void set_offset_store(int offset_store) { this->ptr->offset_store = offset_store; } const char *get_rack_id() const { return this->ptr->rack_id; } bool set_rack_id(const char *rack_id) { char *p = strdup(rack_id); if (!p) return false; free(this->ptr->rack_id); this->ptr->rack_id = p; return true; } const char *get_sasl_mech() const { return this->ptr->mechanisms; } /* Replaces the stored mechanism string, then calls kafka_sasl_set_mechanisms(); returns false if the copy fails or that call returns non-zero. */ bool set_sasl_mech(const char *mechanisms) { char *p = strdup(mechanisms); if (!p) return false; free(this->ptr->mechanisms); this->ptr->mechanisms = p; if (kafka_sasl_set_mechanisms(this->ptr) != 0) return false; return true; } const char *get_sasl_username() const { return this->ptr->username; } bool set_sasl_username(const char *username) { return kafka_sasl_set_username(username, this->ptr) == 0; } const char *get_sasl_password() const { return this->ptr->password; } bool set_sasl_password(const char *password) { return kafka_sasl_set_password(password, this->ptr) == 0; } std::string get_sasl_info() const; /* Invokes the client_new callback stored in the config; true when it returns 0. */ bool new_client(kafka_sasl_t *sasl) { return this->ptr->client_new(this->ptr, sasl) == 0; } public: /* Each handle shares one kafka_config_t through a heap-allocated counter; the handle that drops the counter to zero deinitializes and deletes it. */ KafkaConfig() { this->ptr = new kafka_config_t; kafka_config_init(this->ptr); this->ref = new std::atomic(1); } virtual ~KafkaConfig() { if (--*this->ref == 0) { kafka_config_deinit(this->ptr); delete this->ptr; delete
this->ref; } } /* Move: take over move's config and counter, and leave move holding a freshly initialized config with its own counter. */ KafkaConfig(KafkaConfig&& move) { this->ptr = move.ptr; this->ref = move.ref; move.ptr = new kafka_config_t; kafka_config_init(move.ptr); move.ref = new std::atomic(1); } KafkaConfig& operator= (KafkaConfig&& move) { if (this != &move) { this->~KafkaConfig(); this->ptr = move.ptr; this->ref = move.ref; move.ptr = new kafka_config_t; kafka_config_init(move.ptr); move.ref = new std::atomic(1); } return *this; } /* Copy: share copy's config and bump the counter. */ KafkaConfig(const KafkaConfig& copy) { this->ptr = copy.ptr; this->ref = copy.ref; ++*this->ref; } /* NOTE(review): "©" below looks like a mis-decoded "&copy" -- confirm against the original header. */ KafkaConfig& operator= (const KafkaConfig& copy) { if (this != ©) { this->~KafkaConfig(); this->ptr = copy.ptr; this->ref = copy.ref; ++*this->ref; } return *this; } kafka_config_t *get_raw_ptr() { return this->ptr; } private: kafka_config_t *ptr; std::atomic *ref; }; /* Reference-counted handle around a kafka_record_t; it also embeds a list_head (see get_list()) for linking into a list. */ class KafkaRecord { public: bool set_key(const void *key, size_t key_len) { return kafka_record_set_key(key, key_len, this->ptr) == 0; } /* Hands back the stored key pointer and length; no copy is made. */ void get_key(const void **key, size_t *key_len) const { *key = this->ptr->key; *key_len = this->ptr->key_len; } size_t get_key_len() const { return this->ptr->key_len; } bool set_value(const void *value, size_t value_len) { return kafka_record_set_value(value, value_len, this->ptr) == 0; } /* Hands back the stored value pointer and length; no copy is made. */ void get_value(const void **value, size_t *value_len) const { *value = this->ptr->value; *value_len = this->ptr->value_len; } size_t get_value_len() const { return this->ptr->value_len; } bool add_header_pair(const void *key, size_t key_len, const void *val, size_t val_len); bool add_header_pair(const std::string& key, const std::string& val); struct list_head *get_list() { return &this->list; } /* Topic and partition are read through the record's toppar pointer. */ const char *get_topic() const { return this->ptr->toppar->topic_name; } void set_status(short err) { this->ptr->status = err; } short get_status() const { return this->ptr->status; } int get_partition() const { return this->ptr->toppar->partition; } long long get_offset() const { return this->ptr->offset; } void set_offset(long long offset) { this->ptr->offset = offset; }
long long get_timestamp() const { return this->ptr->timestamp; } void set_timestamp(long long timestamp) { this->ptr->timestamp = timestamp; } public: KafkaRecord() { this->ptr = new kafka_record_t; kafka_record_init(this->ptr); this->ref = new std::atomic(1); } /* The handle that drops the counter to zero deinitializes and deletes the record. */ ~KafkaRecord() { if (--*this->ref == 0) { kafka_record_deinit(this->ptr); delete this->ptr; delete this->ref; } } /* Move: take over move's record and counter; move is left with a freshly initialized record. */ KafkaRecord(KafkaRecord&& move) { this->ptr = move.ptr; this->ref = move.ref; move.ptr = new kafka_record_t; kafka_record_init(move.ptr); move.ref = new std::atomic(1); } KafkaRecord& operator= (KafkaRecord&& move) { if (this != &move) { this->~KafkaRecord(); this->ptr = move.ptr; this->ref = move.ref; move.ptr = new kafka_record_t; kafka_record_init(move.ptr); move.ref = new std::atomic(1); } return *this; } KafkaRecord(const KafkaRecord& copy) { this->ptr = copy.ptr; this->ref = copy.ref; ++*this->ref; } /* NOTE(review): "©" below looks like a mis-decoded "&copy" -- confirm against the original header. */ KafkaRecord& operator= (const KafkaRecord& copy) { if (this != ©) { this->~KafkaRecord(); this->ptr = copy.ptr; this->ref = copy.ref; ++*this->ref; } return *this; } kafka_record_t *get_raw_ptr() const { return this->ptr; } struct list_head *get_header_list() const { return &this->ptr->header_list; } private: struct list_head list; kafka_record_t *ptr; std::atomic *ref; friend class KafkaMessage; friend class KafkaResponse; friend class KafkaToppar; }; class KafkaMeta; class KafkaBroker; class KafkaToppar; /* NOTE(review): the alias targets below read "KafkaList"/"KafkaMap" with no template arguments; the arguments appear to have been lost from this text -- confirm. */ using KafkaMetaList = KafkaList; using KafkaBrokerList = KafkaList; using KafkaBrokerMap = KafkaMap; using KafkaTopparList = KafkaList; using KafkaRecordList = KafkaList; extern KafkaToppar *get_toppar(const char *topic, int partition, KafkaTopparList *toppar_list); extern const KafkaMeta *get_meta(const char *topic, KafkaMetaList *meta_list); /* Reference-counted handle around a kafka_topic_partition_t, plus cursor pointers (curpos/startpos/endpos) into its record list. */ class KafkaToppar { public: bool set_topic_partition(const std::string& topic, int partition) { return kafka_topic_partition_set_tp(topic.c_str(), partition, this->ptr) == 0; } int get_preferred_read_replica() const { return
this->ptr->preferred_read_replica; } bool set_topic(const char *topic) { this->ptr->topic_name = strdup(topic); return this->ptr->topic_name != NULL; } const char *get_topic() const { return this->ptr->topic_name; } int get_partition() const { return this->ptr->partition; } long long get_offset() const { return this->ptr->offset; } void set_offset(long long offset) { this->ptr->offset = offset; } long long get_offset_timestamp() const { return this->ptr->offset_timestamp; } void set_offset_timestamp(long long tm) { this->ptr->offset_timestamp = tm; } long long get_high_watermark() const { return this->ptr->high_watermark; } void set_high_watermark(long long offset) const { this->ptr->high_watermark = offset; } long long get_low_watermark() const { return this->ptr->low_watermark; } void set_low_watermark(long long offset) { this->ptr->low_watermark = offset; } void clear_records() { INIT_LIST_HEAD(&this->ptr->record_list); this->curpos = &this->ptr->record_list; this->startpos = this->endpos = this->curpos; } public: KafkaToppar() { this->ptr = new kafka_topic_partition_t; kafka_topic_partition_init(this->ptr); this->ref = new std::atomic(1); this->curpos = &this->ptr->record_list; this->startpos = this->endpos = this->curpos; } ~KafkaToppar(); KafkaToppar(KafkaToppar&& move) { this->ptr = move.ptr; this->ref = move.ref; move.ptr = new kafka_topic_partition_t; kafka_topic_partition_init(move.ptr); move.ref = new std::atomic(1); this->curpos = &this->ptr->record_list; this->startpos = this->endpos = this->curpos; } KafkaToppar& operator= (KafkaToppar&& move) { if (this != &move) { this->~KafkaToppar(); this->ptr = move.ptr; this->ref = move.ref; move.ptr = new kafka_topic_partition_t; kafka_topic_partition_init(move.ptr); move.ref = new std::atomic(1); this->curpos = &this->ptr->record_list; this->startpos = this->endpos = this->curpos; } return *this; } KafkaToppar(const KafkaToppar& copy) { this->ptr = copy.ptr; this->ref = copy.ref; ++*this->ref; this->curpos = 
copy.curpos; this->startpos = copy.startpos; this->endpos = copy.endpos; } /* NOTE(review): "©" below looks like a mis-decoded "&copy" -- confirm against the original header. */ KafkaToppar& operator= (const KafkaToppar& copy) { if (this != ©) { this->~KafkaToppar(); this->ptr = copy.ptr; this->ref = copy.ref; ++*this->ref; this->curpos = copy.curpos; this->startpos = copy.startpos; this->endpos = copy.endpos; } return *this; } kafka_topic_partition_t *get_raw_ptr() { return this->ptr; } struct list_head *get_list() { return &this->list; } struct list_head *get_record() { return &this->ptr->record_list; } void set_error(short error) { this->ptr->error = error; } int get_error() const { return this->ptr->error; } /* Moves record into a heap-allocated KafkaRecord and links it at the tail of this partition's record list. */ void add_record(KafkaRecord&& record) { KafkaRecord *tmp = new KafkaRecord; *tmp =std::move(record); list_add_tail(tmp->get_list(), &this->ptr->record_list); } /* Put the cursor back on the list head. */ void record_rewind() { this->curpos = &this->ptr->record_list; } /* Advance the cursor and return the record there, or NULL when the cursor is already on the last record. */ KafkaRecord *get_record_next() { if (this->curpos->next == &this->ptr->record_list) return NULL; this->curpos = this->curpos->next; return list_entry(this->curpos, KafkaRecord, list); } /* Step the cursor back and list_del() the record it was on; the KafkaRecord object is not deleted here. */ void del_record_cur() { assert(this->curpos != &this->ptr->record_list); this->curpos = this->curpos->prev; list_del(this->curpos->next); } struct list_head *get_record_startpos() { return this->startpos; } struct list_head *get_record_endpos() { return this->endpos; } /* NOTE(review): endpos is NULL after this call, so record_reach_end() reports false until save_record_endpos() runs again -- confirm this is intended. */ void restore_record_curpos() { this->curpos = this->startpos; this->endpos = NULL; } void save_record_startpos() { this->startpos = this->curpos; } /* Remembers the node after the cursor as the end position. */ void save_record_endpos() { this->endpos = this->curpos->next; } bool record_reach_end() { return this->endpos == &this->ptr->record_list; } void record_rollback() { this->curpos = this->curpos->prev; } /* Last record in the list, or NULL when the list is empty. */ KafkaRecord *get_tail_record() { if (&this->ptr->record_list != this->ptr->record_list.prev) { return (KafkaRecord *)list_entry(this->ptr->record_list.prev, KafkaRecord, list); } else { return NULL; } } private: struct list_head list; kafka_topic_partition_t *ptr; std::atomic *ref; struct list_head *curpos; struct list_head *startpos; struct
list_head *endpos; friend class KafkaMessage; friend class KafkaRequest; friend class KafkaResponse; friend class KafkaList; friend class KafkaCgroup; friend KafkaToppar *get_toppar(const char *topic, int partition, KafkaTopparList *toppar_list); }; /* Handle around a kafka_broker_t. When built from a raw pointer, ref is NULL and the destructor leaves the broker alone; otherwise the broker is shared through the counter. */ class KafkaBroker { public: const char *get_host() const { return this->ptr->host; } int get_port() const { return this->ptr->port; } /* Builds "host:port" as a new string. */ std::string get_host_port() const { std::string host_port(this->ptr->host); host_port += ":"; host_port += std::to_string(this->ptr->port); return host_port; } int get_error() { return this->ptr->error; } public: KafkaBroker() { this->ptr = new kafka_broker_t; kafka_broker_init(this->ptr); this->ref = new std::atomic(1); } /* Only counted handles (ref != NULL) ever deinitialize and delete the broker. */ ~KafkaBroker() { if (this->ref && --*this->ref == 0) { kafka_broker_deinit(this->ptr); delete this->ptr; delete this->ref; } } /* Wraps an existing kafka_broker_t without a counter. */ KafkaBroker(kafka_broker_t *ptr) { this->ptr = ptr; this->ref = NULL; } /* Move: take over move's pointer and counter; move is left with a freshly initialized broker and a new counter. */ KafkaBroker(KafkaBroker&& move) { this->ptr = move.ptr; this->ref = move.ref; move.ptr = new kafka_broker_t; kafka_broker_init(move.ptr); move.ref = new std::atomic(1); } KafkaBroker& operator= (KafkaBroker&& move) { if (this != &move) { this->~KafkaBroker(); this->ptr = move.ptr; this->ref = move.ref; move.ptr = new kafka_broker_t; kafka_broker_init(move.ptr); move.ref = new std::atomic(1); } return *this; } KafkaBroker(const KafkaBroker& copy) { this->ptr = copy.ptr; this->ref = copy.ref; if (this->ref) ++*this->ref; } /* NOTE(review): "©" below looks like a mis-decoded "&copy" -- confirm against the original header. */ KafkaBroker& operator= (const KafkaBroker& copy) { if (this != ©) { this->~KafkaBroker(); this->ptr = copy.ptr; this->ref = copy.ref; if (this->ref) ++*this->ref; } return *this; } /* Ordering compares the "host:port" strings. */ bool operator< (const KafkaBroker& broker) const { return this->get_host_port() < broker.get_host_port(); } bool operator> (const KafkaBroker& broker) const { return this->get_host_port() > broker.get_host_port(); } kafka_broker_t *get_raw_ptr() const { return this->ptr; } struct list_head *get_list() { return &this->list; } struct rb_node *get_rb() { return &this->rb; } int
get_node_id() const { return this->ptr->node_id; } /* Same value as get_node_id(). */ int get_id () const { return this->ptr->node_id; } private: struct list_head list; struct rb_node rb; kafka_broker_t *ptr; std::atomic *ref; friend class KafkaList; friend class KafkaMap; }; /* Reference-counted handle around a kafka_meta_t (a topic name and its partitions array). */ class KafkaMeta { public: const char *get_topic() const { return this->ptr->topic_name; } /* Returns the leader of the partition whose partition_index equals partition; NULL when partition >= partition_elements or no entry matches. */ const kafka_broker_t *get_broker(int partition) const { if (partition >= this->ptr->partition_elements) return NULL; for (int i = 0; i < this->ptr->partition_elements; ++i) { if (partition == this->ptr->partitions[i]->partition_index) return &this->ptr->partitions[i]->leader; } return NULL; } kafka_partition_t **get_partitions() const { return this->ptr->partitions; } int get_partition_elements() const { return this->ptr->partition_elements; } public: KafkaMeta() { this->ptr = new kafka_meta_t; kafka_meta_init(this->ptr); this->ref = new std::atomic(1); } /* The handle that drops the counter to zero deinitializes and deletes the meta. */ ~KafkaMeta() { if (--*this->ref == 0) { kafka_meta_deinit(this->ptr); delete this->ptr; delete this->ref; } } /* Move: take over move's meta and counter; move is left with a freshly initialized meta. */ KafkaMeta(KafkaMeta&& move) { this->ptr = move.ptr; this->ref = move.ref; move.ptr = new kafka_meta_t; kafka_meta_init(move.ptr); move.ref = new std::atomic(1); } KafkaMeta& operator= (KafkaMeta&& move) { if (this != &move) { this->~KafkaMeta(); this->ptr = move.ptr; this->ref = move.ref; move.ptr = new kafka_meta_t; kafka_meta_init(move.ptr); move.ref = new std::atomic(1); } return *this; } KafkaMeta(const KafkaMeta& copy) { this->ptr = copy.ptr; this->ref = copy.ref; ++*this->ref; } /* NOTE(review): "©" below looks like a mis-decoded "&copy" -- confirm against the original header. */ KafkaMeta& operator= (const KafkaMeta& copy) { if (this != ©) { this->~KafkaMeta(); this->ptr = copy.ptr; this->ref = copy.ref; ++*this->ref; } return *this; } kafka_meta_t *get_raw_ptr() { return this->ptr; } bool set_topic(const std::string& topic) { return kafka_meta_set_topic(topic.c_str(), this->ptr) == 0; } struct list_head *get_list() { return &this->list; } int get_error() const { return this->ptr->error; } bool create_partitions(int partition_cnt); bool create_replica_nodes(int
partition_idx, int replica_cnt) { /* NOTE(review): sized as replica_cnt * 4 bytes (assumes sizeof(int) == 4); any previous replica_nodes pointer is overwritten without being freed here -- confirm. */ int *replica_nodes = (int *)malloc(replica_cnt * 4); if (!replica_nodes) return false; this->ptr->partitions[partition_idx]->replica_nodes = replica_nodes; this->ptr->partitions[partition_idx]->replica_node_elements = replica_cnt; return true; } /* Same shape as create_replica_nodes(), for the isr_nodes array; the same review notes apply. */ bool create_isr_nodes(int partition_idx, int isr_cnt) { int *isr_nodes = (int *)malloc(isr_cnt * 4); if (!isr_nodes) return false; this->ptr->partitions[partition_idx]->isr_nodes = isr_nodes; this->ptr->partitions[partition_idx]->isr_node_elements = isr_cnt; return true; } private: struct list_head list; kafka_meta_t *ptr; std::atomic *ref; friend class KafkaList; friend const KafkaMeta *get_meta(const char *topic, KafkaMetaList *meta_list); }; /* Pairs a KafkaMeta pointer with the members collected through add_member(). NOTE(review): the element type of the std::vector members appears to have been stripped from this text. */ class KafkaMetaSubscriber { public: void set_meta(KafkaMeta *meta) { this->meta = meta; } const KafkaMeta *get_meta() const { return this->meta; } void add_member(kafka_member_t *member) { this->member_vec.push_back(member); } const std::vector *get_member() const { return &this->member_vec; } void sort_by_member(); private: KafkaMeta *meta; std::vector member_vec; }; /* Handle around a kafka_cgroup_t; constructors, destructor and assignment are defined out of line. */ class KafkaCgroup { public: const char *get_group() const { return this->ptr->group_name; } const char *get_protocol_type() const { return this->ptr->protocol_type; } const char *get_protocol_name() const { return this->ptr->protocol_name; } int get_generation_id() const { return this->ptr->generation_id; } const char *get_member_id() const { return this->ptr->member_id; } public: KafkaCgroup(); ~KafkaCgroup(); KafkaCgroup(KafkaCgroup&& move); KafkaCgroup& operator= (KafkaCgroup&& move); KafkaCgroup(const KafkaCgroup& copy); KafkaCgroup& operator= (const KafkaCgroup& copy); kafka_cgroup_t *get_raw_ptr() { return this->ptr; } /* Stores a new[]-allocated, NUL-terminated copy of group. NOTE(review): the previous group_name pointer is overwritten without being released here. */ void set_group(const std::string& group) { char *p = new char[group.size() + 1]; strncpy(p, group.c_str(), group.size()); p[group.size()] = 0; this->ptr->group_name = p; } struct list_head *get_list() { return &this->list; } int get_error() const { return
this->ptr->error; } /* True when leader_id and member_id compare equal with strcmp(). */ bool is_leader() const { return strcmp(this->ptr->leader_id, this->ptr->member_id) == 0; } struct list_head *get_group_protocol() { return &this->ptr->group_protocol_list; } /* Frees the old member id and stores a strdup() copy of p; NOTE(review): the strdup() result is not checked. */ void set_member_id(const char *p) { free(this->ptr->member_id); this->ptr->member_id = strdup(p); } void set_error(short error) { this->ptr->error = error; } bool create_members(int member_cnt); kafka_member_t **get_members() const { return this->ptr->members; } int get_member_elements() { return this->ptr->member_elements; } void add_assigned_toppar(KafkaToppar *toppar); struct list_head *get_assigned_toppar_list() { return &this->ptr->assigned_toppar_list; } /* Map a list node back to the KafkaToppar that embeds it. */ KafkaToppar *get_assigned_toppar_by_pos(struct list_head *pos) { return list_entry(pos, KafkaToppar, list); } void assigned_toppar_rewind(); KafkaToppar *get_assigned_toppar_next(); void del_assigned_toppar_cur(); /* Creates, on first use, a KafkaBroker wrapper around &ptr->coordinator and returns the same wrapper afterwards. */ KafkaBroker *get_coordinator() { if (!this->coordinator) this->coordinator = new KafkaBroker(&this->ptr->coordinator); return this->coordinator; } int run_assignor(KafkaMetaList *meta_list, const char *protocol_name); void add_subscriber(KafkaMetaList *meta_list, std::vector *subscribers); static int kafka_range_assignor(kafka_member_t **members, int member_elements, void *meta_topic); static int kafka_roundrobin_assignor(kafka_member_t **members, int member_elements, void *meta_topic); private: struct list_head list; kafka_cgroup_t *ptr; std::atomic *ref; struct list_head *curpos; KafkaBroker *coordinator; }; /* Owns a kafka_block_t with no reference counting; moving leaves the source holding a freshly initialized block. */ class KafkaBlock { public: KafkaBlock() { this->ptr = new kafka_block_t; kafka_block_init(this->ptr); } ~KafkaBlock() { kafka_block_deinit(this->ptr); delete this->ptr; } KafkaBlock(KafkaBlock&& move) { this->ptr = move.ptr; move.ptr = new kafka_block_t; kafka_block_init(move.ptr); } /* Destroys the current block through an explicit destructor call, then takes over move's block. */ KafkaBlock& operator= (KafkaBlock&& move) { if (this != &move) { this->~KafkaBlock(); this->ptr = move.ptr; move.ptr = new kafka_block_t; kafka_block_init(move.ptr); } return *this; } kafka_block_t *get_raw_ptr() const {
return this->ptr; } struct list_head *get_list() { return &this->list; } void *get_block() const { return this->ptr->buf; } size_t get_len() const { return this->ptr->len; } /* malloc() a new buffer of len bytes and free() the old one; the block is left unchanged when malloc() fails. */ bool allocate(size_t len) { void *p = malloc(len); if (!p) return false; free(this->ptr->buf); this->ptr->buf = p; this->ptr->len = len; return true; } /* realloc() the buffer to len bytes; buf and len are left unchanged on failure. */ bool reallocate(size_t len) { void *p = realloc(this->ptr->buf, len); if (p) { this->ptr->buf = p; this->ptr->len = len; return true; } else return false; } /* Allocates len bytes and copies buf into them. */ bool set_block(void *buf, size_t len) { if (!this->allocate(len)) return false; memcpy(this->ptr->buf, buf, len); return true; } /* Stores buf/len directly without copying; the previously stored buffer pointer is simply overwritten. */ void set_block_nocopy(void *buf, size_t len) { this->ptr->buf = buf; this->ptr->len = len; } /* Changes only the recorded length, not the allocation. */ void set_len(size_t len) { this->ptr->len = len; } private: struct list_head list; kafka_block_t *ptr; friend class KafkaBuffer; friend class KafkaList; }; /* A list of KafkaBlock with a running byte total (buf_size) and an optional insertion point inside the list. */ class KafkaBuffer { public: KafkaBuffer() { this->insert_pos = NULL; this->insert_curpos = NULL; this->buf_size = 0; this->inited = false; this->insert_buf_size = 0; this->insert_flag = false; } /* Subtracts n from the running byte total. */ void backup(size_t n) { this->buf_size -= n; } void list_splice(KafkaBuffer *buffer); /* Appends block and adds its length to buf_size (and to insert_buf_size while insert_flag is set). */ void add_item(KafkaBlock block) { if (this->insert_flag) this->insert_buf_size += block.get_len(); this->buf_size += block.get_len(); this->block_list.add_item(std::move(block)); } /* Remembers the current tail as the insertion point and starts counting inserted bytes from zero. */ void set_insert_pos() { this->insert_pos = this->block_list.get_tail(); this->insert_flag = true; this->insert_buf_size = 0; } void block_insert_rewind() { this->insert_flag = false; this->insert_curpos = this->insert_pos; } /* Walks forward from the insertion cursor; NULL once the next node is the list head. */ KafkaBlock *get_block_insert_next() { if (this->insert_curpos->next == this->block_list.get_head()) return NULL; this->insert_curpos = this->insert_curpos->next; return list_entry(this->insert_curpos, KafkaBlock, list); } KafkaBlock *get_block_tail() { return this->block_list.get_tail_entry(); } /* Links block right after the insertion point, which then advances to the newly linked block. */ void insert_list(KafkaBlock *block) { this->buf_size += block->get_len(); this->block_list.insert_pos(block->get_list(), this->insert_pos);
this->insert_pos = this->insert_pos->next; } KafkaBlock *get_block_first() { this->block_list.rewind(); return this->block_list.get_next(); } KafkaBlock *get_block_next() { return this->block_list.get_next(); } /* Copies n bytes into a new block and appends it. NOTE(review): the set_block() result is ignored, yet buf_size still grows by n. */ void append(const char *bytes, size_t n) { KafkaBlock block; block.set_block((void *)bytes, n); this->block_list.add_item(std::move(block)); this->buf_size += n; } size_t get_size() const { return this->buf_size; } size_t peek(const char **buf); /* Adds offset to cur_pos.second and returns offset unchanged. */ long seek(long offset) { this->cur_pos.second += offset; return offset; } struct list_head *get_head() { return this->block_list.get_head(); } private: KafkaList block_list; std::pair cur_pos; struct list_head *insert_pos; struct list_head *insert_curpos; size_t buf_size; size_t insert_buf_size; bool inited; bool insert_flag; }; /* snappy::Sink that forwards every Append() to a KafkaBuffer. */ class KafkaSnappySink : public snappy::Sink { public: KafkaSnappySink(KafkaBuffer *buffer) { this->buffer = buffer; } virtual void Append(const char *bytes, size_t n) { this->buffer->append(bytes, n); } size_t size() const { return this->buffer->get_size(); } KafkaBuffer *get_buffer() const { return buffer; } private: KafkaBuffer *buffer; }; /* snappy::Source reading from a KafkaBuffer; the total size is captured at construction and pos counts the bytes skipped so far. */ class KafkaSnappySource : public snappy::Source { public: KafkaSnappySource(KafkaBuffer *buffer) { this->buffer = buffer; this->buf_size = this->buffer->get_size(); this->pos = 0; } virtual size_t Available() const { return this->buf_size - this->pos; } virtual const char *Peek(size_t *len) { const char *pos; *len = this->buffer->peek(&pos); return pos; } virtual void Skip(size_t n) { this->pos += this->buffer->seek(n); } private: KafkaBuffer *buffer; size_t buf_size; size_t pos; }; } #endif workflow-0.11.8/src/protocol/KafkaMessage.cc000066400000000000000000002417261476003635400207600ustar00rootroot00000000000000/* Copyright (c) 2020 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Authors: Wang Zhulei (wangzhulei@sogou-inc.com) */ /* NOTE(review): the header names of the #include directives below appear to have been stripped from this text -- confirm against the original file. */ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "crc32c.h" #include "EncodeStream.h" #include "KafkaMessage.h" namespace protocol { /* CHECK_RET: evaluate exp once and return its value from the enclosing function when it is negative. */ #define CHECK_RET(exp) \ do { \ int tmp = exp; \ if (tmp < 0) \ return tmp; \ } while (0) /* 64-bit byte-order conversion built from two htonl() calls; the value is returned unchanged when htonl(1) == 1. */ #ifndef htonll static uint64_t htonll(uint64_t x) { if (1 == htonl(1)) return x; else return ((uint64_t)htonl(x & 0xFFFFFFFF) << 32) + htonl(x >> 32); } #endif /* append_* helpers: the std::string overloads append to buf; the void ** overloads write at *buf and advance *buf past what was written. Multi-byte integers go through htons/htonl/htonll first. */ static size_t append_bool(std::string& buf, bool val) { unsigned char v = 0; if (val) v = 1; buf.append((char *)&v, 1); return 1; } static size_t append_i8(std::string& buf, int8_t val) { buf.append((char *)&val, 1); return 1; } static size_t append_i8(void **buf, int8_t val) { *(char *)*buf = val; *buf = (char *)*buf + 1; return 1; } static size_t append_i16(std::string& buf, int16_t val) { int16_t v = htons(val); buf.append((char *)&v, 2); return 2; } static size_t append_i32(std::string& buf, int32_t val) { int32_t v = htonl(val); buf.append((char *)&v, 4); return 4; } /* NOTE(review): stores through an int32_t pointer cast of *buf, so *buf is presumably expected to be suitably aligned -- confirm. */ static size_t append_i32(void **buf, int32_t val) { int32_t v = htonl(val); *(int32_t *)*buf = v; *buf = (int32_t *)*buf + 1; return 4; } static size_t append_i64(std::string& buf, int64_t val) { int64_t v = htonll(val); buf.append((char *)&v, 8); return 8; } static size_t append_i64(void **buf, int64_t val) { int64_t v = htonll(val); *(int64_t *)*buf = v; *buf = (int64_t *)*buf + 1; return 8; } /* Writes a 16-bit length followed by len bytes of str. */ static size_t append_string(std::string& buf, const char *str, size_t len) {
append_i16(buf, len); buf.append(str, len); return len + 2; } static size_t append_string(std::string& buf, const char *str) { if (!str) return append_string(buf, "", 0); return append_string(buf, str, strlen(str)); } static size_t append_string_raw(std::string& buf, const char *str, size_t len) { buf.append(str, len); return len; } static size_t append_nullable_string(std::string& buf, const char *str, size_t len) { if (len == 0) return append_i16(buf, -1); else return append_string(buf, str, len); } static size_t append_string_raw(void **buf, const char *str, size_t len) { memcpy(*buf, str, len); *buf = (char *)*buf + len; return len; } static size_t append_string_raw(void **buf, const std::string& str) { return append_string_raw(buf, str.c_str(), str.size()); } static size_t append_bytes(std::string& buf, const char *str, size_t len) { append_i32(buf, len); buf.append(str, len); return 4 + len; } static size_t append_bytes(std::string& buf, const std::string& str) { return append_bytes(buf, str.c_str(), str.size()); } static size_t append_bytes(void **buf, const char *str, size_t len) { *((int32_t *)*buf) = htonl(len); *buf = (int32_t *)*buf + 1; memcpy(*buf, str, len); *buf = (char *)*buf + len; return len + 2; } static size_t append_nullable_bytes(void **buf, const char *str, size_t len) { if (len == 0) return append_i32(buf, -1); else return append_bytes(buf, str, len); } static size_t append_varint_u64(std::string& buf, uint64_t num) { size_t len = 0; do { unsigned char v = (num & 0x7f) | (num > 0x7f ? 
0x80 : 0); buf.append((char *)&v, 1); num >>= 7; ++len; } while (num); return len; } static inline size_t append_varint_i64(std::string& buf, int64_t num) { return append_varint_u64(buf, (num << 1) ^ (num >> 63)); } static inline size_t append_varint_i32(std::string& buf, int32_t num) { return append_varint_i64(buf, num); } static size_t append_compact_string(std::string& buf, const char *str) { if (!str || str[0] == '\0') append_string(buf, ""); size_t len = strlen(str); size_t r = append_varint_u64(buf, len + 1); append_string_raw(buf, str, len); return r + len; } static inline int parse_i8(void **buf, size_t *size, int8_t *val) { if (*size >= 1) { *val = *(int8_t *)*buf; *size -= sizeof(int8_t); *buf = (int8_t *)*buf + 1; return 0; } errno = EBADMSG; return -1; } static inline int parse_i16(void **buf, size_t *size, int16_t *val) { if (*size >= 2) { *val = ntohs(*(int16_t *)*buf); *size -= sizeof(int16_t); *buf = (int16_t *)*buf + 1; return 0; } errno = EBADMSG; return -1; } static inline int parse_i32(void **buf, size_t *size, int32_t *val) { if (*size >= 4) { *val = ntohl(*(int32_t *)*buf); *size -= sizeof(int32_t); *buf = (int32_t *)*buf + 1; return 0; } errno = EBADMSG; return -1; } static inline int parse_i64(void **buf, size_t *size, int64_t *val) { if (*size >= 8) { *val = htonll(*(int64_t *)*buf); *size -= sizeof(int64_t); *buf = (int64_t *)*buf + 1; return 0; } errno = EBADMSG; return -1; } static int parse_string(void **buf, size_t *size, std::string& str); static int parse_string(void **buf, size_t *size, char **str); static int parse_bytes(void **buf, size_t *size, std::string& str); static int parse_bytes(void **buf, size_t *size, void **str, size_t *str_len); static int parse_varint_u64(void **buf, size_t *size, uint64_t *val); static int parse_varint_i64(void **buf, size_t *size, int64_t *val) { uint64_t n; int ret = parse_varint_u64(buf, size, &n); if (ret == 0) *val = (int64_t)(n >> 1) ^ -(int64_t)(n & 1); return ret; } static int 
parse_varint_i32(void **buf, size_t *size, int32_t *val) { int64_t v = 0; if (parse_varint_i64(buf, size, &v) < 0) return -1; *val = (int32_t)v; return 0; } static const LZ4F_preferences_t kPrefs = { .frameInfo = {LZ4F_default, LZ4F_blockIndependent, }, .compressionLevel = 0, }; /* Feeds block to the codec selected by compress_type, whose state is passed in env. For gzip, lz4 and zstd the block is replaced by the compressed output; for snappy the block's bytes are appended to the KafkaBuffer in env and the block is left as is; any other type returns 0 without touching block. On failure the codec state in env is released and -1 is returned. NOTE(review): the static_cast target types appear to have been stripped from this text. */ static int compress_buf(KafkaBlock *block, int compress_type, void *env) { z_stream *c_stream; size_t total_in = 0, gzip_in, bound_size; KafkaBuffer *snappy_buffer; KafkaBlock nblock; LZ4F_errorCode_t lz4_r; LZ4F_cctx *lz4_cctx; ZSTD_CStream *zstd_cctx; size_t zstd_r; ZSTD_outBuffer out; ZSTD_inBuffer in; switch (compress_type) { case Kafka_Gzip: c_stream = static_cast(env); /* total_in is cumulative on the stream, so remember where this block starts. */ gzip_in = c_stream->total_in; while (total_in < block->get_len()) { if (c_stream->avail_in == 0) { c_stream->next_in = (Bytef *)block->get_block(); c_stream->avail_in = block->get_len() - total_in; } if (c_stream->avail_out == 0) { bound_size = compressBound(c_stream->avail_in); if (!nblock.allocate(bound_size)) { delete c_stream; return -1; } c_stream->next_out = (Bytef *)nblock.get_block(); c_stream->avail_out = bound_size; } /* NOTE(review): the error paths delete c_stream without a visible deflateEnd() -- confirm the caller's expectations. */ if (deflate(c_stream, Z_NO_FLUSH) != Z_OK) { delete c_stream; errno = EBADMSG; return -1; } total_in += c_stream->total_in - gzip_in; gzip_in = c_stream->total_in; } *block = std::move(nblock); break; case Kafka_Snappy: snappy_buffer = static_cast(env); snappy_buffer->append((const char *)block->get_block(), block->get_len()); break; case Kafka_Lz4: lz4_cctx = static_cast(env); bound_size = LZ4F_compressBound(block->get_len(), &kPrefs); if (!nblock.allocate(bound_size)) { LZ4F_freeCompressionContext(lz4_cctx); return -1; } lz4_r = LZ4F_compressUpdate(lz4_cctx, nblock.get_block(), nblock.get_len(), block->get_block(), block->get_len(), NULL); if (LZ4F_isError(lz4_r)) { LZ4F_freeCompressionContext(lz4_cctx); errno = EBADMSG; return -1; } /* Shrink the recorded length to what LZ4F_compressUpdate() reported. */ nblock.set_len(lz4_r); *block = std::move(nblock); break; case Kafka_Zstd: zstd_cctx = static_cast(env); bound_size = ZSTD_compressBound(block->get_len()); if
(!nblock.allocate(bound_size)) { ZSTD_freeCStream(zstd_cctx); return -1; } in.src = block->get_block(); in.pos = 0; in.size = block->get_len(); out.dst = nblock.get_block(); out.pos = 0; out.size = nblock.get_len(); zstd_r = ZSTD_compressStream(zstd_cctx, &out, &in); /* Input left unconsumed after the call is treated as a failure. */ if (ZSTD_isError(zstd_r) || in.pos < in.size) { ZSTD_freeCStream(zstd_cctx); errno = EBADMSG; return -1; } nblock.set_len(out.pos); *block = std::move(nblock); break; default: return 0; } return 0; } /* Inflates n bytes into block in two passes: pass 1 decodes into a 512-byte scratch buffer (overwritten on every round) only to learn total_out and allocate block; pass 2 decodes again into block. Returns 0 on success, -1 on failure (errno set to EBADMSG for stream errors). */ static int gzip_decompress(void *compressed, size_t n, KafkaBlock *block) { for (int pass = 1; pass <= 2; pass++) { z_stream strm = {0}; gz_header hdr; char buf[512]; char *p; int len; int r; if ((r = inflateInit2(&strm, 15 | 32)) != Z_OK) { errno = EBADMSG; return -1; } strm.next_in = (Bytef *)compressed; strm.avail_in = n; if ((r = inflateGetHeader(&strm, &hdr)) != Z_OK) { inflateEnd(&strm); errno = EBADMSG; return -1; } if (pass == 1) { p = buf; len = sizeof(buf); } else { p = (char *)block->get_block(); len = block->get_len(); } do { strm.next_out = (unsigned char *)p; strm.avail_out = len; r = inflate(&strm, Z_NO_FLUSH); switch (r) { case Z_STREAM_ERROR: case Z_NEED_DICT: case Z_DATA_ERROR: case Z_MEM_ERROR: inflateEnd(&strm); errno = EBADMSG; return -1; } /* Only pass 2 advances the output window; pass 1 reuses the scratch buffer. */ if (pass == 2) { p += len - strm.avail_out; len -= len - strm.avail_out; } } while (strm.avail_out == 0 && r != Z_STREAM_END); if (pass == 1) { if (!block->allocate(strm.total_out)) { inflateEnd(&strm); return -1; } } inflateEnd(&strm); /* Every input byte must have been consumed and the stream must have ended. */ if (strm.total_in != n || r != Z_STREAM_END) { errno = EBADMSG; return -1; } } return 0; } /* Input is a sequence of chunks, each a 4-byte big-endian length followed by that many bytes of snappy data. Pass 1 sums the uncompressed sizes and allocates block; pass 2 decompresses each chunk into it. */ static int kafka_snappy_java_uncompress(const char *inbuf, size_t inlen, KafkaBlock *block) { char *obuf = NULL; for (int pass = 1; pass <= 2; pass++) { ssize_t off = 0; ssize_t uoff = 0; while (off + 4 <= (ssize_t)inlen) { uint32_t clen; size_t ulen; memcpy(&clen, inbuf + off, 4); clen = ntohl(clen); off += 4; /* Reject a chunk length that runs past the end of the input. */ if (clen > inlen - off) { errno = EBADMSG; return -1; } if (snappy_uncompressed_length(inbuf + off, clen, &ulen) !=
SNAPPY_OK) { errno = EBADMSG; return -1; } if (pass == 1) { off += clen; uoff += ulen; continue; } size_t n = block->get_len() - uoff; if (snappy_uncompress(inbuf + off, clen, obuf + uoff, &n) != SNAPPY_OK) { errno = EBADMSG; return -1; } off += clen; uoff += ulen; } if (off != (ssize_t)inlen) { errno = EBADMSG; return -1; } if (pass == 1) { if (uoff <= 0) { errno = EBADMSG; return -1; } if (!block->allocate(uoff)) return -1; obuf = (char *)block->get_block(); } else block->set_len(uoff); } return 0; } static int snappy_decompress(void *buf, size_t n, KafkaBlock *block) { const char *inbuf = (const char *)buf; size_t inlen = n; static const unsigned char snappy_java_magic[] = { 0x82, 'S','N','A','P','P','Y', 0 }; static const size_t snappy_java_hdrlen = 8 + 4 + 4; if (!memcmp(buf, snappy_java_magic, 8)) { inbuf = inbuf + snappy_java_hdrlen; inlen -= snappy_java_hdrlen; return kafka_snappy_java_uncompress(inbuf, inlen, block); } else { size_t uncompressed_len; if (snappy_uncompressed_length(inbuf, n, &uncompressed_len) != SNAPPY_OK) { errno = EBADMSG; return -1; } if (!block->allocate(uncompressed_len)) return -1; size_t nn = block->get_len(); return snappy_uncompress(inbuf, n, (char *)block->get_block(), &nn); } } static int lz4_decompress(void *buf, size_t n, KafkaBlock *block) { LZ4F_errorCode_t code; LZ4F_decompressionContext_t dctx; LZ4F_frameInfo_t fi; size_t in_sz, out_sz; size_t in_off, out_off; size_t r; size_t uncompressed_size; size_t outlen; char *out = NULL; size_t inlen = n; const char *inbuf = (const char *)buf; code = LZ4F_createDecompressionContext(&dctx, LZ4F_VERSION); if (LZ4F_isError(code)) { code = LZ4F_freeDecompressionContext(dctx); errno = EBADMSG; return -1; } in_sz = n; r = LZ4F_getFrameInfo(dctx, &fi, (const void *)buf, &in_sz); if (LZ4F_isError(r)) { code = LZ4F_freeDecompressionContext(dctx); errno = EBADMSG; return -1; } if (fi.contentSize == 0 || fi.contentSize > inlen * 255) uncompressed_size = inlen * 4; else uncompressed_size = 
(size_t)fi.contentSize; if (!block->allocate(uncompressed_size)) { code = LZ4F_freeDecompressionContext(dctx); return -1; } out = (char *)block->get_block(); outlen = block->get_len(); in_off = in_sz; out_off = 0; while (in_off < inlen) { out_sz = outlen - out_off; in_sz = inlen - in_off; r = LZ4F_decompress(dctx, out + out_off, &out_sz, inbuf + in_off, &in_sz, NULL); if (LZ4F_isError(r)) { code = LZ4F_freeDecompressionContext(dctx); errno = EBADMSG; return -1; } if (!(out_off + out_sz <= outlen && in_off + in_sz <= inlen)) { code = LZ4F_freeDecompressionContext(dctx); errno = EBADMSG; return -1; } out_off += out_sz; in_off += in_sz; if (r == 0) break; if (out_off == outlen) { size_t extra = outlen * 3 / 4; if (!block->reallocate(outlen + extra)) { code = LZ4F_freeDecompressionContext(dctx); errno = EBADMSG; return -1; } out = (char *)block->get_block(); outlen += extra; } } if (in_off < inlen) { code = LZ4F_freeDecompressionContext(dctx); errno = EBADMSG; return -1; } LZ4F_freeDecompressionContext(dctx); return 0; } static int zstd_decompress(void *buf, size_t n, KafkaBlock *block) { unsigned long long out_bufsize = ZSTD_getFrameContentSize(buf, n); switch (out_bufsize) { case ZSTD_CONTENTSIZE_UNKNOWN: out_bufsize = n * 2; break; case ZSTD_CONTENTSIZE_ERROR: errno = EBADMSG; return -1; default: break; } while (1) { size_t ret; if (!block->allocate(out_bufsize)) return -1; ret = ZSTD_decompress(block->get_block(), out_bufsize, buf, n); if (!ZSTD_isError(ret)) return 0; if (ZSTD_getErrorCode(ret) == ZSTD_error_dstSize_tooSmall) { out_bufsize += out_bufsize * 2; } else { errno = EBADMSG; return -1; } } } static int uncompress_buf(void *buf, size_t size, KafkaBlock *block, int compress_type) { switch(compress_type) { case Kafka_Gzip: return gzip_decompress(buf, size, block); case Kafka_Snappy: return snappy_decompress(buf, size, block); case Kafka_Lz4: return lz4_decompress(buf, size, block); case Kafka_Zstd: return zstd_decompress(buf, size, block); default: errno = 
EBADMSG; return -1; } } static int append_message_set(KafkaBlock *block, const KafkaRecord *record, int offset, int msg_version, const KafkaConfig& config, void *env, int cur_msg_size) { const void *key; size_t key_len; record->get_key(&key, &key_len); const void *value; size_t value_len; record->get_value(&value, &value_len); int message_size = 4 + 1 + 1 + 4 + 4 + key_len + value_len; if (msg_version == 1) message_size += 8; int max_msg_size = std::min(config.get_produce_msgset_max_bytes(), config.get_produce_msg_max_bytes()); if (message_size + 8 + 4 + cur_msg_size > max_msg_size) return 1; if (!block->allocate(message_size + 8 + 4)) return -1; void *cur = block->get_block(); append_i64(&cur, offset); append_i32(&cur, message_size); int crc_32 = crc32(0, NULL, 0); append_i32(&cur, crc_32); //need update append_i8(&cur, msg_version); append_i8(&cur, 0); if (msg_version == 1) append_i64(&cur, record->get_timestamp()); append_bytes(&cur, (const char *)key, key_len); append_nullable_bytes(&cur, (const char *)value, value_len); char *crc_buf = (char *)block->get_block() + 8 + 4; crc_32 = crc32(crc_32, (Bytef *)(crc_buf + 4), message_size - 4); *(uint32_t *)crc_buf = htonl(crc_32); if (compress_buf(block, config.get_compress_type(), env) < 0) return -1; return 0; } static int append_batch_record(KafkaBlock *block, const KafkaRecord *record, int offset, const KafkaConfig& config, int64_t first_timestamp, void *env, int cur_msg_size) { const void *key; size_t key_len; record->get_key(&key, &key_len); const void *value; size_t value_len; record->get_value(&value, &value_len); std::string klen_str; std::string vlen_str; std::string timestamp_delta_str; int64_t timestamp_delta = record->get_timestamp() - first_timestamp; append_varint_i64(timestamp_delta_str, timestamp_delta); std::string offset_delta_str; append_varint_i64(offset_delta_str, offset); if (key_len > 0) append_varint_i32(klen_str, (int32_t)key_len); else append_varint_i32(klen_str, (int32_t)-1); if (value) 
append_varint_i32(vlen_str, (int32_t)value_len); else append_varint_i32(vlen_str, -1); struct list_head *pos; kafka_record_header_t *header; std::string hdr_str; int hdr_cnt = 0; list_for_each(pos, record->get_header_list()) { header = list_entry(pos, kafka_record_header_t, list); append_varint_i32(hdr_str, (int32_t)header->key_len); append_string_raw(hdr_str, (const char *)header->key, header->key_len); append_varint_i32(hdr_str, (int32_t)header->value_len); append_string_raw(hdr_str, (const char *)header->value, header->value_len); ++hdr_cnt; } std::string hdr_cnt_str; append_varint_i32(hdr_cnt_str, hdr_cnt); int length = 1 + timestamp_delta_str.size() + offset_delta_str.size() + klen_str.size() + key_len + vlen_str.size() + value_len + hdr_cnt_str.size() + hdr_str.size(); std::string length_str; append_varint_i32(length_str, length); int max_msg_size = std::min(config.get_produce_msgset_max_bytes(), config.get_produce_msg_max_bytes()); if ((int)(length + length_str.size() + cur_msg_size) > max_msg_size) return 1; if (!block->allocate(length + length_str.size())) return false; void *cur = block->get_block(); append_string_raw(&cur, length_str); append_i8(&cur, 0); append_string_raw(&cur, timestamp_delta_str); append_string_raw(&cur, offset_delta_str); append_string_raw(&cur, klen_str); if (key_len > 0) append_string_raw(&cur, (const char *)key, key_len); append_string_raw(&cur, vlen_str); if (value_len > 0) append_string_raw(&cur, (const char *)value, value_len); append_string_raw(&cur, hdr_cnt_str); if (hdr_cnt > 0) append_string_raw(&cur, hdr_str); if (compress_buf(block, config.get_compress_type(), env) < 0) return -1; return 0; } static int append_record(KafkaBlock *block, const KafkaRecord *record, int offset, int msg_version, const KafkaConfig& config, int64_t first_timestamp, void *env, int cur_msg_size) { if (config.get_produce_msgset_cnt() < offset) return 1; int ret = 0; switch (msg_version) { case 0: case 1: ret = append_message_set(block, record, 
offset, msg_version, config, env, cur_msg_size); break; case 2: ret = append_batch_record(block, record, offset, config, first_timestamp, env, cur_msg_size); break; default: break; } return ret; } static int parse_string(void **buf, size_t *size, std::string& str) { if (*size >= 2) { int16_t len; if (parse_i16(buf, size, &len) >= 0) { if (len >= -1) { if (len == -1) len = 0; if (*size >= (size_t)len) { str.assign((char *)*buf, len); *size -= len; *buf = (char *)*buf + len; return 0; } else { *buf = (char *)*buf - 2; *size += 2; } } } } errno = EBADMSG; return -1; } static int parse_string(void **buf, size_t *size, char **str) { if (*size >= 2) { int16_t len; if (parse_i16(buf, size, &len) >= 0) { if (len >= -1) { if (len == -1) len = 0; if (*size >= (size_t)len) { char *p = (char *)malloc(len + 1); if (!p) { *buf = (char *)*buf - 2; *size += 2; return -1; } free(*str); memcpy((void *)p, *buf, len); p[len] = 0; *size -= len; *buf = (char *)*buf + len; *str = p; return 0; } else { *buf = (char *)*buf - 2; *size += 2; } } } } errno = EBADMSG; return -1; } static int parse_bytes(void **buf, size_t *size, std::string& str) { if (*size >= 4) { int32_t len; if (parse_i32(buf, size, &len) >= 0) { if (len == -1) len = 0; if (*size >= (size_t)len) { str.assign((char *)*buf, len); *size -= len; *buf = (char *)*buf + len; return 0; } else { *buf = (char *)*buf - 4; *size += 4; } } } errno = EBADMSG; return -1; } static int parse_bytes(void **buf, size_t *size, void **str, size_t *str_len) { if (*size >= 4) { int32_t len; if (parse_i32(buf, size, &len) >= 0) { if (len == -1) len = 0; if (*size >= (size_t)len) { *str = *buf; *str_len = len; *size -= len; *buf = (char *)*buf + len; return 0; } else { *buf = (char *)*buf - 4; *size += 4; } } } errno = EBADMSG; return -1; } static int parse_varint_u64(void **buf, size_t *size, uint64_t *val) { size_t off = 0; uint64_t num = 0; int shift = 0; size_t org_size = *size; do { if (*size == 0) { *size = org_size; errno = EBADMSG; return 
-1; /* Underflow */ } num |= (uint64_t)(((char *)(*buf))[(int)off] & 0x7f) << shift; shift += 7; } while (((char *)(*buf))[(int)off++] & 0x80); *val = num; *buf = (char *)(*buf) + off; *size -= off; return 0; } int KafkaMessage::parse_message_set(void **buf, size_t *size, bool check_crcs, int msg_vers, struct list_head *record_list, KafkaBuffer *uncompressed, KafkaToppar *toppar) { int64_t offset; int32_t message_size; int32_t crc; if (parse_i64(buf, size, &offset) < 0) return -1; if (parse_i32(buf, size, &message_size) < 0) return -1; if (*size < (size_t)(message_size - 8)) return 1; if (parse_i32(buf, size, &crc) < 0) return -1; if (check_crcs) { int32_t crc_32 = crc32(0, NULL, 0); crc_32 = crc32(crc_32, (Bytef *)*buf, message_size - 4); if (crc_32 != crc) { errno = EBADMSG; return -1; } } int8_t magic; int8_t attributes; if (parse_i8(buf, size, &magic) < 0) return -1; if (parse_i8(buf, size, &attributes) < 0) return -1; int64_t timestamp = -1; if (msg_vers == 1 && parse_i64(buf, size, ×tamp) < 0) return -1; void *key; size_t key_len; if (parse_bytes(buf, size, &key, &key_len) < 0) return -1; void *payload; size_t payload_len; if (parse_bytes(buf, size, &payload, &payload_len) < 0) return -1; if (offset >= toppar->get_offset()) { int compress_type = attributes & 3; if (compress_type == 0) { KafkaRecord *kafka_record = new KafkaRecord; kafka_record_t *record = kafka_record->get_raw_ptr(); record->key = key; record->key_len = key_len; record->timestamp = timestamp; record->offset = offset; record->toppar = toppar->get_raw_ptr(); record->key_is_moved = 1; record->value_is_moved = 1; record->value = payload; record->value_len = payload_len; list_add_tail(kafka_record->get_list(), record_list); } else { KafkaBlock block; if (uncompress_buf(payload, payload_len, &block, compress_type) < 0) return -1; struct list_head *record_head = record_list->prev; void *uncompressed_ptr = block.get_block(); size_t uncompressed_len = block.get_len(); 
parse_message_set(&uncompressed_ptr, &uncompressed_len, check_crcs, msg_vers, record_list, uncompressed, toppar); uncompressed->add_item(std::move(block)); if (msg_vers == 1) { struct list_head *pos; KafkaRecord *record; int n = 0; for (pos = record_head->next; pos != record_list; pos = pos->next) n++; for (pos = record_head->next; pos != record_list; pos = pos->next) { int64_t fix_offset; record = list_entry(pos, KafkaRecord, list); fix_offset = offset + record->get_offset() - n + 1; record->set_offset(fix_offset); } } } } if (*size > 0) { return parse_message_set(buf, size, check_crcs, msg_vers, record_list, uncompressed, toppar); } return 0; } static int parse_varint_bytes(void **buf, size_t *size, void **str, size_t *str_len) { int64_t len = 0; if (parse_varint_i64(buf, size, &len) >= 0) { if (len >= -1) { if (len <= 0) { *str = NULL; *str_len = 0; return 0; } if ((int64_t)*size >= len) { *str = *buf; *str_len = (size_t)len; *size -= len; *buf = (char *)*buf + len; return 0; } } } errno = EBADMSG; return -1; } struct KafkaBatchRecordHeader { int64_t base_offset; int32_t length; int32_t partition_leader_epoch; int8_t magic; int32_t crc; int16_t attributes; int32_t last_offset_delta; int64_t base_timestamp; int64_t max_timestamp; int64_t produce_id; int16_t producer_epoch; int32_t base_sequence; int32_t record_count; }; int KafkaMessage::parse_message_record(void **buf, size_t *size, kafka_record_t *record) { int64_t length; int8_t attributes; int64_t timestamp_delta; int64_t offset_delta; int32_t hdr_size; if (parse_varint_i64(buf, size, &length) < 0) return -1; if (parse_i8(buf, size, &attributes) < 0) return -1; if (parse_varint_i64(buf, size, ×tamp_delta) < 0) return -1; if (parse_varint_i64(buf, size, &offset_delta) < 0) return -1; record->timestamp += timestamp_delta; record->offset += offset_delta; if (parse_varint_bytes(buf, size, &record->key, &record->key_len) < 0) return -1; if (parse_varint_bytes(buf, size, &record->value, &record->value_len) < 0) 
return -1; if (parse_varint_i32(buf, size, &hdr_size) < 0) return -1; for (int i = 0; i < hdr_size; ++i) { kafka_record_header_t *header; header = (kafka_record_header_t *)malloc(sizeof(kafka_record_header_t)); if (!header) return -1; kafka_record_header_init(header); if (parse_varint_bytes(buf, size, &header->key, &header->key_len) < 0) { free(header); return -1; } if (parse_varint_bytes(buf, size, &header->value, &header->value_len) < 0) { kafka_record_header_deinit(header); free(header); return -1; } header->key_is_moved = 1; header->value_is_moved = 1; list_add_tail(&header->list, &record->header_list); } return record->offset < record->toppar->offset ? 1 : 0; } int KafkaMessage::parse_record_batch(void **buf, size_t *size, bool check_crcs, struct list_head *record_list, KafkaBuffer *uncompressed, KafkaToppar *toppar) { KafkaBatchRecordHeader hdr; if (parse_i64(buf, size, &hdr.base_offset) < 0) return -1; if (parse_i32(buf, size, &hdr.length) < 0) return -1; if (parse_i32(buf, size, &hdr.partition_leader_epoch) < 0) return -1; if (parse_i8(buf, size, &hdr.magic) < 0) return -1; if (parse_i32(buf, size, &hdr.crc) < 0) return -1; if (check_crcs) { if (hdr.length > (int)*size + 9) { errno = EBADMSG; return -1; } if ((int)crc32c(0, (const void *)*buf, hdr.length - 9) != hdr.crc) { errno = EBADMSG; return -1; } } if (parse_i16(buf, size, &hdr.attributes) < 0) return -1; if (parse_i32(buf, size, &hdr.last_offset_delta) < 0) return -1; if (parse_i64(buf, size, &hdr.base_timestamp) < 0) return -1; if (parse_i64(buf, size, &hdr.max_timestamp) < 0) return -1; if (parse_i64(buf, size, &hdr.produce_id) < 0) return -1; if (parse_i16(buf, size, &hdr.producer_epoch) < 0) return -1; if (parse_i32(buf, size, &hdr.base_sequence) < 0) return -1; if (parse_i32(buf, size, &hdr.record_count) < 0) return -1; if (*size < (size_t)(hdr.length - 61 + 12)) return 1; KafkaBlock block; int compress_type = hdr.attributes & 7; if (compress_type != 0) { if (uncompress_buf(*buf, hdr.length - 61 
+ 12, &block, compress_type) < 0) return -1; *buf = (char *)*buf + hdr.length - 61 + 12; *size -= hdr.length - 61 + 12; } void *p = *buf; size_t n = *size; if (block.get_len() > 0) { p = block.get_block(); n = block.get_len(); } for (int i = 0; i < hdr.record_count; ++i) { KafkaRecord *record = new KafkaRecord; record->set_offset(hdr.base_offset); record->set_timestamp(hdr.base_timestamp); record->get_raw_ptr()->key_is_moved = 1; record->get_raw_ptr()->value_is_moved = 1; record->get_raw_ptr()->toppar = toppar->get_raw_ptr(); switch (parse_message_record(&p, &n, record->get_raw_ptr())) { case -1: delete record; return -1; case 0: list_add_tail(record->get_list(), record_list); break; default: delete record; break; } } if (compress_type == 0) { *buf = p; *size = n; } if (block.get_len() > 0) uncompressed->add_item(std::move(block)); return 0; } int KafkaMessage::parse_records(void **buf, size_t *size, bool check_crcs, KafkaBuffer *uncompressed, KafkaToppar *toppar) { struct list_head *record_list = toppar->get_record(); int msg_set_size = 0; if (parse_i32(buf, size, &msg_set_size) < 0) return -1; if (msg_set_size == 0) return 0; if (msg_set_size < 0) return -1; if (*size < 17) return -1; size_t msg_size = msg_set_size; while (msg_size > 16) { int ret = -1; char magic = ((char *)(*buf))[16]; switch(magic) { case 0: case 1: ret = parse_message_set(buf, &msg_size, check_crcs, magic, record_list, uncompressed, toppar); break; case 2: ret = parse_record_batch(buf, &msg_size, check_crcs, record_list, uncompressed, toppar); break; default: break; } if (ret > 0) { *size -= msg_set_size; *buf = (char *)*buf + msg_size; return 0; } else if (ret < 0) break; } *size -= msg_set_size; *buf = (char *)*buf + msg_size; return 0; } KafkaMessage::KafkaMessage() { static struct Crc32cInitializer { Crc32cInitializer() { crc32c_global_init(); } } initializer; this->parser = new kafka_parser_t; kafka_parser_init(this->parser); this->stream = new EncodeStream; this->api_type = 
Kafka_Unknown; this->correlation_id = 0; this->cur_size = 0; } KafkaMessage::~KafkaMessage() { if (this->parser) { kafka_parser_deinit(this->parser); delete this->parser; delete this->stream; } } KafkaMessage::KafkaMessage(KafkaMessage&& msg) : ProtocolMessage(std::move(msg)) { this->parser = msg.parser; this->stream = msg.stream; msg.parser = NULL; msg.stream = NULL; this->msgbuf = std::move(msg.msgbuf); this->headbuf = std::move(msg.headbuf); this->toppar_list = std::move(msg.toppar_list); this->serialized = std::move(msg.serialized); this->uncompressed = std::move(msg.uncompressed); this->api_type = msg.api_type; msg.api_type = Kafka_Unknown; this->compress_env = msg.compress_env; msg.compress_env = NULL; this->cur_size = msg.cur_size; msg.cur_size = 0; } KafkaMessage& KafkaMessage::operator= (KafkaMessage &&msg) { if (this != &msg) { *(ProtocolMessage *)this = std::move(msg); if (this->parser) { kafka_parser_deinit(this->parser); delete this->parser; delete this->stream; } this->parser = msg.parser; this->stream = msg.stream; msg.parser = NULL; msg.stream = NULL; this->msgbuf = std::move(msg.msgbuf); this->headbuf = std::move(msg.headbuf); this->toppar_list = std::move(msg.toppar_list); this->serialized = std::move(msg.serialized); this->uncompressed = std::move(msg.uncompressed); this->api_type = msg.api_type; msg.api_type = Kafka_Unknown; this->compress_env = msg.compress_env; msg.compress_env = NULL; this->cur_size = msg.cur_size; msg.cur_size = 0; } return *this; } int KafkaMessage::encode_message(int api_type, struct iovec vectors[], int max) { const auto it = this->encode_func_map.find(api_type); if (it == this->encode_func_map.cend()) return -1; return it->second(vectors, max); } static int kafka_api_get_max_ver(int api_type) { switch (api_type) { case Kafka_Metadata: return 4; case Kafka_Produce: return 7; case Kafka_Fetch: return 11; case Kafka_FindCoordinator: return 2; case Kafka_JoinGroup: return 5; case Kafka_SyncGroup: return 3; case 
Kafka_Heartbeat: return 3; case Kafka_OffsetFetch: return 1; case Kafka_OffsetCommit: return 7; case Kafka_ListOffsets: return 1; case Kafka_LeaveGroup: return 1; case Kafka_ApiVersions: return 0; case Kafka_SaslHandshake: return 1; case Kafka_SaslAuthenticate: return 0; case Kafka_DescribeGroups: return 0; default: return 0; } } static int kafka_get_api_version(const kafka_api_t *api, const KafkaConfig& conf, int api_type, int max_ver, int message_version) { int min_ver = 0; if (api_type == Kafka_Produce) { if (message_version == 2) min_ver = 3; else if (message_version == 1) min_ver = 1; if (conf.get_compress_type() == Kafka_Zstd) min_ver = 7; } return kafka_broker_get_api_version(api, api_type, min_ver, max_ver); } int KafkaMessage::encode_head() { if (this->api_type == Kafka_ApiVersions) this->api_version = 0; else { int max_ver = kafka_api_get_max_ver(this->api_type); if (this->api->features & KAFKA_FEATURE_MSGVER2) this->message_version = 2; else if (this->api->features & KAFKA_FEATURE_MSGVER1) this->message_version = 1; else this->message_version = 0; if (this->config.get_compress_type() == Kafka_Lz4 && !(this->api->features & KAFKA_FEATURE_LZ4)) { this->config.set_compress_type(Kafka_NoCompress); } if (this->config.get_compress_type() == Kafka_Zstd && !(this->api->features & KAFKA_FEATURE_ZSTD)) { this->config.set_compress_type(Kafka_NoCompress); } this->api_version = kafka_get_api_version(this->api, this->config, this->api_type, max_ver, this->message_version); } if (this->api_version < 0) return -1; append_i32(this->headbuf, 0); append_i16(this->headbuf, this->api_type); append_i16(this->headbuf, this->api_version); append_i32(this->headbuf, this->correlation_id); append_string(this->headbuf, this->config.get_client_id()); return 0; } int KafkaMessage::encode(struct iovec vectors[], int max) { if (encode_head() < 0) return -1; int n = encode_message(this->api_type, vectors + 1, max - 1); if (n < 0) return -1; int msg_size = this->headbuf.size() + 
this->cur_size - 4; *(int32_t *)this->headbuf.c_str() = htonl(msg_size); vectors[0].iov_base = (void *)this->headbuf.c_str(); vectors[0].iov_len = this->headbuf.size(); return n + 1; } int KafkaMessage::append(const void *buf, size_t *size) { int ret = kafka_parser_append_message(buf, size, this->parser); if (ret >= 0) { this->cur_size += *size; if (this->cur_size > this->size_limit) { errno = EMSGSIZE; ret = -1; } } else if (ret == -2) { errno = EBADMSG; ret = -1; } return ret; } static int kafka_compress_prepare(int compress_type, void **env, KafkaBlock *block) { z_stream *c_stream; KafkaBuffer *snappy_buffer; size_t lz4_out_len; LZ4F_errorCode_t lz4_r; LZ4F_cctx *lz4_cctx = NULL; ZSTD_CStream *zstd_cctx; size_t zstd_r; switch (compress_type) { case Kafka_Gzip: c_stream = new z_stream; c_stream->zalloc = (alloc_func)0; c_stream->zfree = (free_func)0; c_stream->opaque = (voidpf)0; if (deflateInit2(c_stream, Z_DEFAULT_COMPRESSION, Z_DEFLATED, 15 | 16, 8, Z_DEFAULT_STRATEGY) != Z_OK) { delete c_stream; errno = EBADMSG; return -1; } c_stream->avail_in = 0; c_stream->avail_out = 0; c_stream->total_in = 0; *env = (void *)c_stream; break; case Kafka_Snappy: snappy_buffer = new KafkaBuffer; *env = (void *)snappy_buffer; break; case Kafka_Lz4: lz4_r = LZ4F_createCompressionContext(&lz4_cctx, LZ4F_VERSION); if (LZ4F_isError(lz4_r)) { LZ4F_freeCompressionContext(lz4_cctx); errno = EBADMSG; return -1; } lz4_out_len = LZ4F_HEADER_SIZE_MAX; if (!block->allocate(lz4_out_len)) { LZ4F_freeCompressionContext(lz4_cctx); return -1; } lz4_r = LZ4F_compressBegin(lz4_cctx, block->get_block(), block->get_len(), &kPrefs); if (LZ4F_isError(lz4_r)) { LZ4F_freeCompressionContext(lz4_cctx); errno = EBADMSG; return -1; } block->set_len(lz4_r); *env = (void *)lz4_cctx; break; case Kafka_Zstd: zstd_cctx = ZSTD_createCStream(); if (!zstd_cctx) return -1; zstd_r = ZSTD_initCStream(zstd_cctx, ZSTD_CLEVEL_DEFAULT); if (ZSTD_isError(zstd_r)) { ZSTD_freeCStream(zstd_cctx); errno = EBADMSG; return -1; 
} *env = (void *)zstd_cctx; break; default: return 0; } return 0; } static int kafka_compress_finish(int compress_type, void *env, KafkaBuffer *buffer, int *addon) { int gzip_err; LZ4F_cctx *lz4_cctx; z_stream *c_stream; size_t out_len; KafkaBuffer *snappy_buffer; LZ4F_errorCode_t lz4_r; ZSTD_CStream *zstd_cctx; size_t zstd_r; ZSTD_outBuffer out; KafkaBlock block; size_t zstd_end_bufsize = ZSTD_compressBound(buffer->get_size()); switch (compress_type) { case Kafka_Gzip: c_stream = static_cast(env); out_len = c_stream->total_out; for(;;) { if (c_stream->avail_out == 0) { block.allocate(1024); c_stream->next_out = (Bytef *)block.get_block(); c_stream->avail_out = 1024; } gzip_err = deflate(c_stream, Z_FINISH); if (gzip_err == Z_STREAM_END) break; if (gzip_err != Z_OK) { delete c_stream; errno = EBADMSG; return -1; } else if (block.get_len() > 0) { size_t use_bytes = block.get_len() - c_stream->avail_out; block.set_len(use_bytes); buffer->add_item(std::move(block)); *addon += use_bytes; } } if (deflateEnd(c_stream) != Z_OK) { delete c_stream; errno = EBADMSG; return -1; } if (block.get_len() > 0) { size_t use_bytes = block.get_len() - c_stream->avail_out; block.set_len(use_bytes); buffer->add_item(std::move(block)); *addon += use_bytes; } else { KafkaBlock *b = buffer->get_block_tail(); size_t use_bytes = b->get_len() - c_stream->avail_out; int remainer = b->get_len() - use_bytes; b->set_len(use_bytes); *addon += -remainer; } delete c_stream; break; case Kafka_Snappy: snappy_buffer = static_cast(env); { KafkaBuffer kafka_buffer_sink; KafkaSnappySource source(snappy_buffer); KafkaSnappySink sink(&kafka_buffer_sink); if (snappy::Compress(&source, &sink) < 0) { delete snappy_buffer; errno = EBADMSG; return -1; } size_t pre_n = buffer->get_size(); buffer->list_splice(&kafka_buffer_sink); *addon = buffer->get_size() - pre_n; } delete snappy_buffer; break; case Kafka_Lz4: lz4_cctx = static_cast(env); out_len = LZ4F_compressBound(0, &kPrefs); if (!block.allocate(out_len)) { 
LZ4F_freeCompressionContext(lz4_cctx); return -1; } lz4_r = LZ4F_compressEnd(lz4_cctx, block.get_block(), block.get_len(), NULL); if (LZ4F_isError(lz4_r)) { LZ4F_freeCompressionContext(lz4_cctx); errno = EBADMSG; return -1; } block.set_len(lz4_r); buffer->add_item(std::move(block)); *addon = lz4_r; LZ4F_freeCompressionContext(lz4_cctx); break; case Kafka_Zstd: zstd_cctx = static_cast(env); if (!block.allocate(zstd_end_bufsize)) return -1; out.dst = block.get_block(); out.pos = 0; out.size = 1024000; zstd_r = ZSTD_endStream(zstd_cctx, &out); if (ZSTD_isError(zstd_r) || zstd_r > 0) { ZSTD_freeCStream(zstd_cctx); errno = EBADMSG; return -1; } block.set_len(out.pos); buffer->add_item(std::move(block)); *addon = out.pos; ZSTD_freeCStream(zstd_cctx); break; default: return 0; } return 0; } KafkaRequest::KafkaRequest() { using namespace std::placeholders; this->encode_func_map[Kafka_Metadata] = std::bind(&KafkaRequest::encode_metadata, this, _1, _2); this->encode_func_map[Kafka_Produce] = std::bind(&KafkaRequest::encode_produce, this, _1, _2); this->encode_func_map[Kafka_Fetch] = std::bind(&KafkaRequest::encode_fetch, this, _1, _2); this->encode_func_map[Kafka_FindCoordinator] = std::bind(&KafkaRequest::encode_findcoordinator, this, _1, _2); this->encode_func_map[Kafka_JoinGroup] = std::bind(&KafkaRequest::encode_joingroup, this, _1, _2); this->encode_func_map[Kafka_SyncGroup] = std::bind(&KafkaRequest::encode_syncgroup, this, _1, _2); this->encode_func_map[Kafka_Heartbeat] = std::bind(&KafkaRequest::encode_heartbeat, this, _1, _2); this->encode_func_map[Kafka_OffsetFetch] = std::bind(&KafkaRequest::encode_offsetfetch, this, _1, _2); this->encode_func_map[Kafka_OffsetCommit] = std::bind(&KafkaRequest::encode_offsetcommit, this, _1, _2); this->encode_func_map[Kafka_ListOffsets] = std::bind(&KafkaRequest::encode_listoffset, this, _1, _2); this->encode_func_map[Kafka_LeaveGroup] = std::bind(&KafkaRequest::encode_leavegroup, this, _1, _2); 
this->encode_func_map[Kafka_ApiVersions] = std::bind(&KafkaRequest::encode_apiversions, this, _1, _2); this->encode_func_map[Kafka_SaslHandshake] = std::bind(&KafkaRequest::encode_saslhandshake, this, _1, _2); this->encode_func_map[Kafka_SaslAuthenticate] = std::bind(&KafkaRequest::encode_saslauthenticate, this, _1, _2); } int KafkaRequest::encode_produce(struct iovec vectors[], int max) { this->stream->reset(vectors, max); //transaction_id if (this->api_version >= 3) append_nullable_string(this->msgbuf, "", 0); append_i16(this->msgbuf, this->config.get_produce_acks()); append_i32(this->msgbuf, this->config.get_produce_timeout()); int topic_cnt = 0; this->toppar_list.rewind(); KafkaToppar *toppar; while ((toppar = this->toppar_list.get_next()) != NULL) { std::string topic_header; KafkaBlock header_block; int record_flag = -1; append_string(topic_header, toppar->get_topic()); append_i32(topic_header, 1); append_i32(topic_header, toppar->get_partition()); append_i32(topic_header, 0); // recordset length if (!header_block.set_block((void *)topic_header.c_str(), topic_header.size())) { return -1; } void *recordset_size_ptr = (void *)((char *)header_block.get_block() + header_block.get_len() - 4); int64_t first_timestamp = 0; int64_t max_timestamp = 0; const int MSGV2HSize = (8 + 4 + 4 + 1 + 4 + 2 + 4 + 8 + 8 + 8 + 2 + 4 + 4); int batch_length = 0; if (this->message_version == 2) batch_length = MSGV2HSize - (8 + 4); size_t cur_serialized_len = this->serialized.get_size(); int batch_cnt = 0; toppar->save_record_startpos(); KafkaRecord *record; while ((record = toppar->get_record_next()) != NULL) { KafkaBlock compress_block; KafkaBlock record_block; struct timespec ts; if (record->get_timestamp() == 0) { clock_gettime(CLOCK_REALTIME, &ts); record->set_timestamp((ts.tv_sec * 1000000000 + ts.tv_nsec) / 1000 / 1000); } if (batch_cnt == 0) { if (kafka_compress_prepare(this->config.get_compress_type(), &this->compress_env, &compress_block) < 0) { return -1; } first_timestamp = 
record->get_timestamp(); } int ret = append_record(&record_block, record, batch_cnt, this->message_version, this->config, first_timestamp, this->compress_env, batch_length); if (ret < 0) return -1; if (ret > 0) { toppar->record_rollback(); toppar->save_record_endpos(); if (record_flag < 0) { errno = EMSGSIZE; return -1; } else record_flag = 1; break; } if (batch_cnt == 0) { this->serialized.add_item(std::move(header_block)); cur_serialized_len = this->serialized.get_size(); this->serialized.set_insert_pos(); if (compress_block.get_len() > 0) this->serialized.add_item(std::move(compress_block)); } if (record_block.get_len() > 0) this->serialized.add_item(std::move(record_block)); record_flag = 0; toppar->save_record_endpos(); max_timestamp = record->get_timestamp(); ++batch_cnt; batch_length += this->serialized.get_size() - cur_serialized_len; cur_serialized_len = this->serialized.get_size(); } if (record_flag < 0) continue; if (this->message_version == 2) { if (this->config.get_compress_type() != Kafka_NoCompress) { int addon = 0; if (kafka_compress_finish(this->config.get_compress_type(), this->compress_env, &this->serialized, &addon) < 0) { return -1; } batch_length += addon; } std::string record_header; append_i64(record_header, 0); append_i32(record_header, batch_length); append_i32(record_header, 0); append_i8(record_header, 2); //magic uint32_t crc_32 = 0; size_t crc32_offset = record_header.size(); append_i32(record_header, crc_32); append_i16(record_header, this->config.get_compress_type()); append_i32(record_header, batch_cnt - 1); append_i64(record_header, first_timestamp); append_i64(record_header, max_timestamp); append_i64(record_header, -1); //produce_id append_i16(record_header, -1); append_i32(record_header, -1); append_i32(record_header, batch_cnt); KafkaBlock *header_block = new KafkaBlock; if (!header_block->set_block((void *)record_header.c_str(), record_header.size())) { delete header_block; return -1; } char *crc_ptr = (char 
*)header_block->get_block() + crc32_offset; this->serialized.insert_list(header_block); crc_32 = crc32c(crc_32, (const void *)(crc_ptr + 4), header_block->get_len() - crc32_offset - 4); this->serialized.block_insert_rewind(); KafkaBlock *block; while ((block = this->serialized.get_block_insert_next()) != NULL) crc_32 = crc32c(crc_32, block->get_block(), block->get_len()); *(uint32_t *)crc_ptr = htonl(crc_32); *(uint32_t *)recordset_size_ptr = htonl(batch_length + 4 + 8); } else { if (this->config.get_compress_type() != Kafka_NoCompress) { int addon = 0; if (kafka_compress_finish(this->config.get_compress_type(), this->compress_env, &this->serialized, &addon) < 0) { return -1; } batch_length += addon; int message_size = 4 + 1 + 1 + 4 + 4 + batch_length; if (this->message_version == 1) message_size += 8; std::string wrap_header; append_i64(wrap_header, 0); append_i32(wrap_header, message_size); int crc_32 = crc32(0, NULL, 0); size_t crc32_offset = wrap_header.size(); append_i32(wrap_header, crc_32); append_i8(wrap_header, this->message_version); append_i8(wrap_header, this->config.get_compress_type()); if (this->message_version == 1) append_i64(wrap_header, first_timestamp); append_bytes(wrap_header, ""); append_i32(wrap_header, batch_length); const char *crc_ptr = (const char *)wrap_header.c_str() + crc32_offset; crc_32 = crc32(crc_32, (Bytef *)(crc_ptr + 4), wrap_header.size() - crc32_offset - 4); this->serialized.block_insert_rewind(); KafkaBlock *block; while ((block = this->serialized.get_block_insert_next()) != NULL) crc_32 = crc32(crc_32, (Bytef *)block->get_block(), block->get_len()); *(uint32_t *)crc_ptr = htonl(crc_32); KafkaBlock *wrap_block = new KafkaBlock; if (!wrap_block->set_block((void *)wrap_header.c_str(), wrap_header.size())) { delete wrap_block; return -1; } this->serialized.insert_list(wrap_block); *(uint32_t *)recordset_size_ptr = htonl(message_size + 8 + 4); } else *(uint32_t *)recordset_size_ptr = htonl(batch_length); } ++topic_cnt; } 
append_i32(this->msgbuf, topic_cnt); this->cur_size += this->msgbuf.size(); this->stream->append_nocopy(this->msgbuf.c_str(), this->msgbuf.size()); vectors[0].iov_base = (void *)this->msgbuf.c_str(); vectors[0].iov_len = this->msgbuf.size(); KafkaBlock *block = this->serialized.get_block_first(); while (block) { this->stream->append_nocopy((const char *)block->get_block(), block->get_len()); this->cur_size += block->get_len(); block = this->serialized.get_block_next(); } return this->stream->size(); } int KafkaRequest::encode_fetch(struct iovec vectors[], int max) { append_i32(this->msgbuf, -1); append_i32(this->msgbuf, this->config.get_fetch_timeout()); append_i32(this->msgbuf, this->config.get_fetch_min_bytes()); if (this->api_version >= 3) append_i32(this->msgbuf, this->config.get_fetch_max_bytes()); //isolation_level if (this->api_version >= 4) append_i8(this->msgbuf, 0); if (this->api_version >= 7) { //sessionid append_i32(this->msgbuf, 0); //epoch append_i32(this->msgbuf, -1); } int topic_cnt_pos = this->msgbuf.size(); append_i32(this->msgbuf, 0); int topic_cnt = 0; this->toppar_list.rewind(); KafkaToppar *toppar; while ((toppar = this->toppar_list.get_next()) != NULL) { append_string(this->msgbuf, toppar->get_topic()); append_i32(this->msgbuf, 1); append_i32(this->msgbuf, toppar->get_partition()); //CurrentLeaderEpoch if (this->api_version >= 9) append_i32(this->msgbuf, -1); append_i64(this->msgbuf, toppar->get_offset()); //LogStartOffset if (this->api_version >= 5) append_i64(this->msgbuf, -1); append_i32(this->msgbuf, this->config.get_fetch_msg_max_bytes()); ++topic_cnt; } *(uint32_t *)(this->msgbuf.c_str() + topic_cnt_pos) = htonl(topic_cnt); //Length of the ForgottenTopics list if (this->api_version >= 7) append_i32(this->msgbuf, 0); //rackid if (this->api_version >= 11) { if (this->config.get_rack_id()) append_compact_string(this->msgbuf, this->config.get_rack_id()); else append_string(this->msgbuf, ""); } this->cur_size = this->msgbuf.size(); 
vectors[0].iov_base = (void *)this->msgbuf.c_str(); vectors[0].iov_len = this->msgbuf.size(); return 1; } int KafkaRequest::encode_metadata(struct iovec vectors[], int max) { int topic_cnt_pos = this->msgbuf.size(); if (this->api_version >= 1) append_i32(this->msgbuf, -1); else append_i32(this->msgbuf, 0); this->meta_list.rewind(); KafkaMeta *meta; int topic_cnt = 0; while ((meta = this->meta_list.get_next()) != NULL) { append_string(this->msgbuf, meta->get_topic()); ++topic_cnt; } if (this->api_version >= 4) { append_bool(this->msgbuf, this->config.get_allow_auto_topic_creation()); } *(uint32_t *)(this->msgbuf.c_str() + topic_cnt_pos) = htonl(topic_cnt); this->cur_size = this->msgbuf.size(); vectors[0].iov_base = (void *)this->msgbuf.c_str(); vectors[0].iov_len = this->msgbuf.size(); return 1; } int KafkaRequest::encode_findcoordinator(struct iovec vectors[], int max) { append_string(this->msgbuf, this->cgroup.get_group()); //coordinator key type if (this->api_version >= 1) append_i8(this->msgbuf, 0); this->cur_size = this->msgbuf.size(); vectors[0].iov_base = (void *)this->msgbuf.c_str(); vectors[0].iov_len = this->msgbuf.size(); return 1; } static std::string kafka_cgroup_gen_metadata(KafkaMetaList& meta_list) { std::string metadata; int meta_pos; int meta_cnt = 0; meta_list.rewind(); KafkaMeta *meta; append_i16(metadata, 2); // version meta_pos = metadata.size(); append_i32(metadata, 0); while ((meta = meta_list.get_next()) != NULL) { append_string(metadata, meta->get_topic()); meta_cnt++; } *(uint32_t *)(metadata.c_str() + meta_pos) = htonl(meta_cnt); //UserData empty append_bytes(metadata, ""); return metadata; } int KafkaRequest::encode_joingroup(struct iovec vectors[], int max) { append_string(this->msgbuf, this->cgroup.get_group()); append_i32(this->msgbuf, this->config.get_session_timeout()); if (this->api_version >= 1) append_i32(this->msgbuf, this->config.get_rebalance_timeout()); //member_id append_string(this->msgbuf, this->cgroup.get_member_id()); 
//group_instance_id if (this->api_version >= 5) append_nullable_string(this->msgbuf, "", 0); append_string(this->msgbuf, this->cgroup.get_protocol_type()); int protocol_pos = this->msgbuf.size(); append_i32(this->msgbuf, 0); int protocol_cnt = 0; struct list_head *pos; kafka_group_protocol_t *group_protocol; list_for_each(pos, this->cgroup.get_group_protocol()) { ++protocol_cnt; group_protocol = list_entry(pos, kafka_group_protocol_t, list); append_string(this->msgbuf, group_protocol->protocol_name); append_bytes(this->msgbuf, kafka_cgroup_gen_metadata(this->meta_list)); } *(uint32_t *)(this->msgbuf.c_str() + protocol_pos) = htonl(protocol_cnt); this->cur_size = this->msgbuf.size(); vectors[0].iov_base = (void *)this->msgbuf.c_str(); vectors[0].iov_len = this->msgbuf.size(); return 1; } std::string KafkaMessage::get_member_assignment(kafka_member_t *member) { std::string assignment; //version append_i16(assignment, 2); size_t topic_cnt_pos = assignment.size(); append_i32(assignment, 0); struct list_head *pos; KafkaToppar *toppar; int topic_cnt = 0; list_for_each(pos, &member->assigned_toppar_list) { toppar = list_entry(pos, KafkaToppar, list); append_string(assignment, toppar->get_topic()); append_i32(assignment, 1); append_i32(assignment, toppar->get_partition()); ++topic_cnt; } //userdata append_bytes(assignment, ""); *(uint32_t *)(assignment.c_str() + topic_cnt_pos) = htonl(topic_cnt); return assignment; } int KafkaRequest::encode_syncgroup(struct iovec vectors[], int max) { append_string(this->msgbuf, this->cgroup.get_group()); append_i32(this->msgbuf, this->cgroup.get_generation_id()); append_string(this->msgbuf, this->cgroup.get_member_id()); //group_instance_id if (this->api_version >= 3) append_nullable_string(this->msgbuf, "", 0); if (this->cgroup.is_leader()) { append_i32(this->msgbuf, this->cgroup.get_member_elements()); for (int i = 0; i < this->cgroup.get_member_elements(); ++i) { kafka_member_t *member = this->cgroup.get_members()[i]; 
append_string(this->msgbuf, member->member_id); append_bytes(this->msgbuf, std::move(get_member_assignment(member))); } } else append_i32(this->msgbuf, 0); this->cur_size = this->msgbuf.size(); vectors[0].iov_base = (void *)this->msgbuf.c_str(); vectors[0].iov_len = this->msgbuf.size(); return 1; } int KafkaRequest::encode_leavegroup(struct iovec vectors[], int max) { append_string(this->msgbuf, this->cgroup.get_group()); append_string(this->msgbuf, this->cgroup.get_member_id()); this->cur_size = this->msgbuf.size(); vectors[0].iov_base = (void *)this->msgbuf.c_str(); vectors[0].iov_len = this->msgbuf.size(); return 1; } int KafkaRequest::encode_listoffset(struct iovec vectors[], int max) { append_i32(this->msgbuf, -1); int topic_cnt = 0; int topic_cnt_pos = this->msgbuf.size(); append_i32(this->msgbuf, 0); struct list_head *pos; KafkaToppar *toppar; list_for_each(pos, this->toppar_list.get_head()) { toppar = this->toppar_list.get_entry(pos); append_string(this->msgbuf, toppar->get_topic()); append_i32(this->msgbuf, 1); append_i32(this->msgbuf, toppar->get_partition()); append_i64(this->msgbuf, toppar->get_offset_timestamp()); if (this->api_version == 0) append_i32(this->msgbuf, 1); ++topic_cnt; } *(uint32_t *)(this->msgbuf.c_str() + topic_cnt_pos) = htonl(topic_cnt); this->cur_size = this->msgbuf.size(); vectors[0].iov_base = (void *)this->msgbuf.c_str(); vectors[0].iov_len = this->msgbuf.size(); return 1; } int KafkaRequest::encode_offsetfetch(struct iovec vectors[], int max) { append_string(this->msgbuf, this->cgroup.get_group()); int topic_cnt = 0; int topic_cnt_pos = this->msgbuf.size(); append_i32(this->msgbuf, 0); this->cgroup.assigned_toppar_rewind(); KafkaToppar *toppar; while ((toppar = this->cgroup.get_assigned_toppar_next()) != NULL) { append_string(this->msgbuf, toppar->get_topic()); append_i32(this->msgbuf, 1); append_i32(this->msgbuf, toppar->get_partition()); ++topic_cnt; } *(uint32_t *)(this->msgbuf.c_str() + topic_cnt_pos) = htonl(topic_cnt); 
this->cur_size = this->msgbuf.size(); vectors[0].iov_base = (void *)this->msgbuf.c_str(); vectors[0].iov_len = this->msgbuf.size(); return 1; } int KafkaRequest::encode_offsetcommit(struct iovec vectors[], int max) { append_string(this->msgbuf, this->cgroup.get_group()); if (this->api_version >= 1) { append_i32(this->msgbuf, this->cgroup.get_generation_id()); append_string(this->msgbuf, this->cgroup.get_member_id()); } //GroupInstanceId if (this->api_version >= 7) append_nullable_string(this->msgbuf, "", 0); //RetentionTime if (this->api_version >= 2 && this->api_version <= 4) append_i64(this->msgbuf, -1); int toppar_cnt = 0; int toppar_cnt_pos = this->msgbuf.size(); append_i32(this->msgbuf, 0); this->toppar_list.rewind(); KafkaToppar *toppar; while ((toppar = this->toppar_list.get_next()) != NULL) { append_string(this->msgbuf, toppar->get_topic()); append_i32(this->msgbuf, 1); append_i32(this->msgbuf, toppar->get_partition()); append_i64(this->msgbuf, toppar->get_offset() + 1); if (this->api_version >= 6) append_i32(this->msgbuf, -1); if (this->api_version == 1) append_i64(this->msgbuf, -1); append_nullable_string(this->msgbuf, "", 0); ++toppar_cnt; } *(uint32_t *)(this->msgbuf.c_str() + toppar_cnt_pos) = htonl(toppar_cnt); this->cur_size = this->msgbuf.size(); vectors[0].iov_base = (void *)this->msgbuf.c_str(); vectors[0].iov_len = this->msgbuf.size(); return 1; } int KafkaRequest::encode_heartbeat(struct iovec vectors[], int max) { append_string(this->msgbuf, this->cgroup.get_group()); append_i32(this->msgbuf, this->cgroup.get_generation_id()); append_string(this->msgbuf, this->cgroup.get_member_id()); //group_instance_id if (this->api_version >= 3) append_nullable_string(this->msgbuf, "", 0); this->cur_size = this->msgbuf.size(); vectors[0].iov_base = (void *)this->msgbuf.c_str(); vectors[0].iov_len = this->msgbuf.size(); return 1; } int KafkaRequest::encode_apiversions(struct iovec vectors[], int max) { return 0; } int KafkaRequest::encode_saslhandshake(struct 
iovec vectors[], int max) { append_string(this->msgbuf, this->config.get_sasl_mech()); this->cur_size = this->msgbuf.size(); vectors[0].iov_base = (void *)this->msgbuf.c_str(); vectors[0].iov_len = this->msgbuf.size(); return 1; } int KafkaRequest::encode_saslauthenticate(struct iovec vectors[], int max) { append_bytes(this->msgbuf, this->sasl->buf, this->sasl->bsize); this->cur_size = this->msgbuf.size(); vectors[0].iov_base = (void *)this->msgbuf.c_str(); vectors[0].iov_len = this->msgbuf.size(); return 1; } KafkaResponse::KafkaResponse() { using namespace std::placeholders; this->parse_func_map[Kafka_Metadata] = std::bind(&KafkaResponse::parse_metadata, this, _1, _2); this->parse_func_map[Kafka_Produce] = std::bind(&KafkaResponse::parse_produce, this, _1, _2); this->parse_func_map[Kafka_Fetch] = std::bind(&KafkaResponse::parse_fetch, this, _1, _2); this->parse_func_map[Kafka_FindCoordinator] = std::bind(&KafkaResponse::parse_findcoordinator, this, _1, _2); this->parse_func_map[Kafka_JoinGroup] = std::bind(&KafkaResponse::parse_joingroup, this, _1, _2); this->parse_func_map[Kafka_SyncGroup] = std::bind(&KafkaResponse::parse_syncgroup, this, _1, _2); this->parse_func_map[Kafka_Heartbeat] = std::bind(&KafkaResponse::parse_heartbeat, this, _1, _2); this->parse_func_map[Kafka_OffsetFetch] = std::bind(&KafkaResponse::parse_offsetfetch, this, _1, _2); this->parse_func_map[Kafka_OffsetCommit] = std::bind(&KafkaResponse::parse_offsetcommit, this, _1, _2); this->parse_func_map[Kafka_ListOffsets] = std::bind(&KafkaResponse::parse_listoffset, this, _1, _2); this->parse_func_map[Kafka_LeaveGroup] = std::bind(&KafkaResponse::parse_leavegroup, this, _1, _2); this->parse_func_map[Kafka_ApiVersions] = std::bind(&KafkaResponse::parse_apiversions, this, _1, _2); this->parse_func_map[Kafka_SaslHandshake] = std::bind(&KafkaResponse::parse_saslhandshake, this, _1, _2); this->parse_func_map[Kafka_SaslAuthenticate] = std::bind(&KafkaResponse::parse_saslauthenticate, this, _1, _2); } 
int KafkaResponse::parse_response() { auto it = this->parse_func_map.find(this->api_type); if (it == this->parse_func_map.end()) { errno = EPROTO; return -1; } void *buf = this->parser->msgbuf; size_t size = this->parser->message_size; int32_t correlation_id; if (parse_i32(&buf, &size, &correlation_id) < 0) return -1; this->correlation_id = correlation_id; int ret = it->second(&buf, &size); if (ret < 0) return -1; if (size != 0) { errno = EBADMSG; return -1; } return ret; } static int kafka_meta_parse_broker(void **buf, size_t *size, int api_version, KafkaBrokerList *broker_list) { int32_t broker_cnt; CHECK_RET(parse_i32(buf, size, &broker_cnt)); if (broker_cnt < 0) { errno = EBADMSG; return -1; } for (int i = 0; i < broker_cnt; ++i) { KafkaBroker broker; kafka_broker_t *ptr = broker.get_raw_ptr(); CHECK_RET(parse_i32(buf, size, &ptr->node_id)); CHECK_RET(parse_string(buf, size, &ptr->host)); CHECK_RET(parse_i32(buf, size, &ptr->port)); if (api_version >= 1) CHECK_RET(parse_string(buf, size, &ptr->rack)); broker_list->rewind(); KafkaBroker *last; while ((last = broker_list->get_next()) != NULL) { if (last->get_node_id() == broker.get_node_id()) { broker_list->del_cur(); delete last; break; } } broker_list->add_item(std::move(broker)); } return 0; } static bool kafka_broker_get_leader(int leader_id, KafkaBrokerList *broker_list, kafka_broker_t *leader) { KafkaBroker *bbroker; broker_list->rewind(); while ((bbroker = broker_list->get_next()) != NULL) { if (bbroker->get_node_id() == leader_id) { kafka_broker_t *broker = bbroker->get_raw_ptr(); char *host = strdup(broker->host); if (host) { char *rack = NULL; if (broker->rack) rack = strdup(broker->rack); if (!broker->rack || rack) { kafka_broker_deinit(leader); *leader = *broker; leader->host = host; leader->rack = rack; return true; } free(host); } return false; } } errno = EBADMSG; return false; } static int kafka_meta_parse_partition(void **buf, size_t *size, KafkaMeta *meta, KafkaBrokerList *broker_list) { int32_t 
leader_id; int32_t replica_cnt, isr_cnt; int32_t partition_cnt; int32_t i, j; CHECK_RET(parse_i32(buf, size, &partition_cnt)); if (partition_cnt < 0) { errno = EBADMSG; return -1; } if (!meta->create_partitions(partition_cnt)) return -1; kafka_partition_t **partition = meta->get_partitions(); for (i = 0; i < partition_cnt; ++i) { int16_t error; int32_t index; if (parse_i16(buf, size, &error) < 0) break; partition[i]->error = error; if (parse_i32(buf, size, &index) < 0) break; partition[i]->partition_index = index; if (parse_i32(buf, size, &leader_id) < 0) break; if (!kafka_broker_get_leader(leader_id, broker_list, &partition[i]->leader)) break; if (parse_i32(buf, size, &replica_cnt) < 0) break; if (!meta->create_replica_nodes(i, replica_cnt)) break; for (j = 0; j < replica_cnt; ++j) { int32_t replica_node; if (parse_i32(buf, size, &replica_node) < 0) break; partition[i]->replica_nodes[j] = replica_node; } if (j != replica_cnt) break; if (parse_i32(buf, size, &isr_cnt) < 0) break; if (!meta->create_isr_nodes(i, isr_cnt)) break; for (j = 0; j < isr_cnt; ++j) { int32_t isr_node; if (parse_i32(buf, size, &isr_node) < 0) break; partition[i]->isr_nodes[j] = isr_node; } if (j != isr_cnt) break; } if (i != partition_cnt) return -1; return 0; } static KafkaMeta *find_meta_by_name(const std::string& topic, KafkaMetaList *meta_list) { meta_list->rewind(); KafkaMeta *meta; while ((meta = meta_list->get_next()) != NULL) { if (meta->get_topic() == topic) return meta; } errno = EBADMSG; return NULL; } static int kafka_meta_parse_topic(void **buf, size_t *size, int api_version, KafkaMetaList *meta_list, KafkaBrokerList *broker_list) { KafkaMetaList lst; int32_t topic_cnt; CHECK_RET(parse_i32(buf, size, &topic_cnt)); for (int32_t topic_idx = 0; topic_idx < topic_cnt; ++topic_idx) { int16_t error; CHECK_RET(parse_i16(buf, size, &error)); std::string topic_name; CHECK_RET(parse_string(buf, size, topic_name)); KafkaMeta *meta = find_meta_by_name(topic_name, meta_list); if (!meta) 
return -1; KafkaMeta new_mta; new_mta.set_topic(topic_name); kafka_meta_t *ptr = new_mta.get_raw_ptr(); ptr->error = error; if (api_version >= 1) CHECK_RET(parse_i8(buf, size, &ptr->is_internal)); CHECK_RET(kafka_meta_parse_partition(buf, size, &new_mta, broker_list)); lst.add_item(std::move(new_mta)); } *meta_list = std::move(lst); return 0; } int KafkaResponse::parse_metadata(void **buf, size_t *size) { int32_t throttle_time, controller_id; std::string cluster_id; if (this->api_version >= 3) CHECK_RET(parse_i32(buf, size, &throttle_time)); CHECK_RET(kafka_meta_parse_broker(buf, size, this->api_version, &this->broker_list)); if (this->api_version >= 2) CHECK_RET(parse_string(buf, size, cluster_id)); if (this->api_version >= 1) CHECK_RET(parse_i32(buf, size, &controller_id)); CHECK_RET(kafka_meta_parse_topic(buf, size, this->api_version, &this->meta_list, &this->broker_list)); return 0; } KafkaToppar *KafkaMessage::find_toppar_by_name(const std::string& topic, int partition, struct list_head *toppar_list) { KafkaToppar *toppar; struct list_head *pos; list_for_each(pos, toppar_list) { toppar = list_entry(pos, KafkaToppar, list); if (toppar->get_topic() == topic && toppar->get_partition() == partition) return toppar; } errno = EBADMSG; return NULL; } KafkaToppar *KafkaMessage::find_toppar_by_name(const std::string& topic, int partition, KafkaTopparList *toppar_list) { toppar_list->rewind(); KafkaToppar *toppar; while ((toppar = toppar_list->get_next()) != NULL) { if (toppar->get_topic() == topic && toppar->get_partition() == partition) return toppar; } errno = EBADMSG; return NULL; } int KafkaResponse::parse_produce(void **buf, size_t *size) { int32_t topic_cnt; std::string topic_name; int32_t partition_cnt; int32_t partition; int64_t base_offset, log_append_time, log_start_offset; int32_t throttle_time; int produce_timeout = this->config.get_produce_timeout() * 2; CHECK_RET(parse_i32(buf, size, &topic_cnt)); for (int32_t topic_idx = 0; topic_idx < topic_cnt; 
++topic_idx) { CHECK_RET(parse_string(buf, size, topic_name)); CHECK_RET(parse_i32(buf, size, &partition_cnt)); for (int32_t i = 0; i < partition_cnt; ++i) { CHECK_RET(parse_i32(buf, size, &partition)); KafkaToppar *toppar = find_toppar_by_name(topic_name, partition, &this->toppar_list); if (!toppar) return -1; kafka_topic_partition_t *ptr = toppar->get_raw_ptr(); CHECK_RET(parse_i16(buf, size, &ptr->error)); log_append_time = -1; CHECK_RET(parse_i64(buf, size, &base_offset)); if (this->api_version >= 2) CHECK_RET(parse_i64(buf, size, &log_append_time)); if (this->api_version >=5) CHECK_RET(parse_i64(buf, size, &log_start_offset)); struct list_head *pos; KafkaRecord *record; if (ptr->error == KAFKA_REQUEST_TIMED_OUT) { toppar->restore_record_curpos(); this->config.set_produce_timeout(produce_timeout); continue; } for (pos = toppar->get_record_startpos()->next; pos != toppar->get_record_endpos(); pos = pos->next) { record = list_entry(pos, KafkaRecord, list); record->set_status(ptr->error); if (ptr->error) continue; record->set_offset(base_offset++); if (log_append_time != -1) record->set_timestamp(log_append_time); } } } if (this->api_version >= 1) CHECK_RET(parse_i32(buf, size, &throttle_time)); return 0; } int KafkaResponse::parse_fetch(void **buf, size_t *size) { int32_t throttle_time; this->toppar_list.rewind(); KafkaToppar *toppar; while ((toppar = this->toppar_list.get_next()) != NULL) toppar->clear_records(); if (this->api_version >= 1) CHECK_RET(parse_i32(buf, size, &throttle_time)); if (this->api_version >= 7) { int16_t error; int32_t sessionid; parse_i16(buf, size, &error); parse_i32(buf, size, &sessionid); } int32_t topic_cnt; std::string topic_name; int32_t partition_cnt; int32_t partition; int32_t aborted_cnt; int32_t preferred_read_replica; int64_t producer_id, first_offset; int64_t high_watermark; CHECK_RET(parse_i32(buf, size, &topic_cnt)); for (int32_t topic_idx = 0; topic_idx < topic_cnt; ++topic_idx) { CHECK_RET(parse_string(buf, size, 
topic_name)); CHECK_RET(parse_i32(buf, size, &partition_cnt)); for (int i = 0; i < partition_cnt; ++i) { CHECK_RET(parse_i32(buf, size, &partition)); KafkaToppar *toppar = find_toppar_by_name(topic_name, partition, &this->toppar_list); if (!toppar) return -1; kafka_topic_partition_t *ptr = toppar->get_raw_ptr(); CHECK_RET(parse_i16(buf, size, &ptr->error)); CHECK_RET(parse_i64(buf, size, &high_watermark)); if (high_watermark > ptr->low_watermark) ptr->high_watermark = high_watermark; if (this->api_version >= 4) { CHECK_RET(parse_i64(buf, size, (int64_t *)&ptr->last_stable_offset)); if (this->api_version >= 5) CHECK_RET(parse_i64(buf, size, (int64_t *)&ptr->log_start_offset)); CHECK_RET(parse_i32(buf, size, &aborted_cnt)); for (int32_t j = 0; j < aborted_cnt; ++j) { CHECK_RET(parse_i64(buf, size, &producer_id)); CHECK_RET(parse_i64(buf, size, &first_offset)); } } if (this->api_version >= 11) { CHECK_RET(parse_i32(buf, size, &preferred_read_replica)); ptr->preferred_read_replica = preferred_read_replica; } if (parse_records(buf, size, this->config.get_check_crcs(), &this->uncompressed, toppar) != 0) { ptr->error = KAFKA_CORRUPT_MESSAGE; return -1; } } } return 0; } int KafkaResponse::parse_listoffset(void **buf, size_t *size) { int32_t throttle_time; int32_t topic_cnt; std::string topic_name; int32_t partition_cnt; int32_t partition; int64_t offset_timestamp, offset; int32_t offset_cnt; if (this->api_version >= 2) CHECK_RET(parse_i32(buf, size, &throttle_time)); CHECK_RET(parse_i32(buf, size, &topic_cnt)); for (int32_t topic_idx = 0; topic_idx < topic_cnt; ++topic_idx) { CHECK_RET(parse_string(buf, size, topic_name)); CHECK_RET(parse_i32(buf, size, &partition_cnt)); for (int32_t i = 0; i < partition_cnt; ++i) { CHECK_RET(parse_i32(buf, size, &partition)); KafkaToppar *toppar = find_toppar_by_name(topic_name, partition, &this->toppar_list); if (!toppar) return -1; kafka_topic_partition_t *ptr = toppar->get_raw_ptr(); CHECK_RET(parse_i16(buf, size, &ptr->error)); if 
(this->api_version == 1) { CHECK_RET(parse_i64(buf, size, &offset_timestamp)); CHECK_RET(parse_i64(buf, size, &offset)); if (ptr->offset_timestamp == -1) ptr->high_watermark = offset; else if (ptr->offset_timestamp == -2) ptr->low_watermark = offset; else ptr->offset = offset; } else if (this->api_version == 0) { CHECK_RET(parse_i32(buf, size, &offset_cnt)); for (int32_t j = 0; j < offset_cnt; ++j) { CHECK_RET(parse_i64(buf, size, &offset)); ptr->offset = offset; } ptr->low_watermark = 0; } } } return 0; } int KafkaResponse::parse_findcoordinator(void **buf, size_t *size) { int32_t throttle_time; if (this->api_version >= 1) CHECK_RET(parse_i32(buf, size, &throttle_time)); kafka_cgroup_t *cgroup = this->cgroup.get_raw_ptr(); CHECK_RET(parse_i16(buf, size, &cgroup->error)); if (this->api_version >= 1) CHECK_RET(parse_string(buf, size, &cgroup->error_msg)); CHECK_RET(parse_i32(buf, size, &cgroup->coordinator.node_id)); CHECK_RET(parse_string(buf, size, &cgroup->coordinator.host)); CHECK_RET(parse_i32(buf, size, &cgroup->coordinator.port)); return 0; } static bool kafka_meta_find_or_add_topic(const std::string& topic_name, KafkaMetaList *meta_list) { meta_list->rewind(); bool find = false; KafkaMeta *meta; while ((meta = meta_list->get_next()) != NULL) { if (topic_name == meta->get_topic()) { find = true; break; } } if (!find) { KafkaMeta tmp; if (!tmp.set_topic(topic_name)) return false; meta_list->add_item(tmp); } return true; } static int kafka_cgroup_parse_member(void **buf, size_t *size, KafkaCgroup *cgroup, KafkaMetaList *meta_list, int api_version) { int32_t member_cnt = 0; CHECK_RET(parse_i32(buf, size, &member_cnt)); if (member_cnt < 0) { errno = EBADMSG; return -1; } if (!cgroup->create_members(member_cnt)) return -1; kafka_member_t **member = cgroup->get_members(); int32_t i; for (i = 0; i < member_cnt; ++i) { if (parse_string(buf, size, &member[i]->member_id) < 0) break; if (api_version >= 5) { std::string group_instance_id; parse_string(buf, size, 
group_instance_id); } if (parse_bytes(buf, size, &member[i]->member_metadata, &member[i]->member_metadata_len) < 0) break; void *metadata = member[i]->member_metadata; size_t metadata_len = member[i]->member_metadata_len; int16_t version; int32_t topic_cnt; std::string topic_name; int32_t j; if (parse_i16(&metadata, &metadata_len, &version) < 0) break; if (parse_i32(&metadata, &metadata_len, &topic_cnt) < 0) break; for (j = 0; j < topic_cnt; ++j) { if (parse_string(&metadata, &metadata_len, topic_name) < 0) break; KafkaToppar * toppar = new KafkaToppar; if (!toppar->set_topic(topic_name.c_str())) { delete toppar; break; } list_add_tail(toppar->get_list(), &member[i]->toppar_list); if (!kafka_meta_find_or_add_topic(topic_name, meta_list)) return -1; } if (j != topic_cnt) break; } if (i != member_cnt) return -1; return 0; } int KafkaResponse::parse_joingroup(void **buf, size_t *size) { int32_t throttle_time; if (this->api_version >= 2) CHECK_RET(parse_i32(buf, size, &throttle_time)); kafka_cgroup_t *cgroup = this->cgroup.get_raw_ptr(); CHECK_RET(parse_i16(buf, size, &cgroup->error)); CHECK_RET(parse_i32(buf, size, &cgroup->generation_id)); CHECK_RET(parse_string(buf, size, &cgroup->protocol_name)); CHECK_RET(parse_string(buf, size, &cgroup->leader_id)); CHECK_RET(parse_string(buf, size, &cgroup->member_id)); CHECK_RET(kafka_cgroup_parse_member(buf, size, &this->cgroup, &this->meta_list, this->api_version)); return 0; } int KafkaMessage::kafka_parse_member_assignment(const char *bbuf, size_t n, KafkaCgroup *cgroup) { void **buf = (void **)&bbuf; size_t *size = &n; int32_t topic_cnt; int32_t partition_cnt; int16_t version; struct list_head *pos, *tmp; std::string topic_name; int32_t partition; list_for_each_safe(pos, tmp, cgroup->get_assigned_toppar_list()) { KafkaToppar *toppar = list_entry(pos, KafkaToppar, list); list_del(pos); delete toppar; } CHECK_RET(parse_i16(buf, size, &version)); CHECK_RET(parse_i32(buf, size, &topic_cnt)); for (int32_t i = 0; i < topic_cnt; 
++i) { CHECK_RET(parse_string(buf, size, topic_name)); CHECK_RET(parse_i32(buf, size, &partition_cnt)); for (int32_t j = 0; j < partition_cnt; ++j) { CHECK_RET(parse_i32(buf, size, &partition)); KafkaToppar *toppar = new KafkaToppar; if (!toppar->set_topic_partition(topic_name, partition)) { delete toppar; return -1; } cgroup->add_assigned_toppar(toppar); } } return 0; } int KafkaResponse::parse_syncgroup(void **buf, size_t *size) { int32_t throttle_time; int16_t error; std::string member_assignment; if (this->api_version >= 1) CHECK_RET(parse_i32(buf, size, &throttle_time)); CHECK_RET(parse_i16(buf, size, &error)); this->cgroup.set_error(error); CHECK_RET(parse_bytes(buf, size, member_assignment)); if (!member_assignment.empty()) { CHECK_RET(kafka_parse_member_assignment(member_assignment.c_str(), member_assignment.size(), &this->cgroup)); } return 0; } int KafkaResponse::parse_leavegroup(void **buf, size_t *size) { int32_t throttle_time; int16_t error; if (this->api_version >= 1) CHECK_RET(parse_i32(buf, size, &throttle_time)); CHECK_RET(parse_i16(buf, size, &error)); this->cgroup.set_error(error); return 0; } int KafkaResponse::parse_offsetfetch(void **buf, size_t *size) { int32_t topic_cnt; std::string topic_name; int32_t partition_cnt; int32_t partition; CHECK_RET(parse_i32(buf, size, &topic_cnt)); for (int32_t topic_idx = 0; topic_idx < topic_cnt; ++topic_idx) { CHECK_RET(parse_string(buf, size, topic_name)); CHECK_RET(parse_i32(buf, size, &partition_cnt)); for (int32_t i = 0; i < partition_cnt; ++i) { CHECK_RET(parse_i32(buf, size, &partition)); KafkaToppar *toppar = find_toppar_by_name(topic_name, partition, this->cgroup.get_assigned_toppar_list()); if (!toppar) return -1; kafka_topic_partition_t *ptr = toppar->get_raw_ptr(); int64_t offset; CHECK_RET(parse_i64(buf, size, &offset)); if (this->config.get_offset_store() != KAFKA_OFFSET_ASSIGN) ptr->offset = offset; CHECK_RET(parse_string(buf, size, &ptr->committed_metadata)); CHECK_RET(parse_i16(buf, size, 
&ptr->error)); } } return 0; } int KafkaResponse::parse_offsetcommit(void **buf, size_t *size) { int32_t throttle_time; int32_t topic_cnt; std::string topic_name; int32_t partition_cnt; int32_t partition; if (this->api_version >= 3) CHECK_RET(parse_i32(buf, size, &throttle_time)); CHECK_RET(parse_i32(buf, size, &topic_cnt)); for (int32_t topic_idx = 0; topic_idx < topic_cnt; ++topic_idx) { CHECK_RET(parse_string(buf, size, topic_name)); CHECK_RET(parse_i32(buf, size, &partition_cnt)); for (int32_t i = 0 ; i < partition_cnt; ++i) { CHECK_RET(parse_i32(buf, size, &partition)); CHECK_RET(parse_i16(buf, size, &this->cgroup.get_raw_ptr()->error)); } } return 0; } int KafkaResponse::parse_heartbeat(void **buf, size_t *size) { int32_t throttle_time; int16_t error; if (this->api_version >= 1) CHECK_RET(parse_i32(buf, size, &throttle_time)); CHECK_RET(parse_i16(buf, size, &error)); this->cgroup.set_error(error); return 0; } static bool kafka_api_version_cmp(const kafka_api_version_t& api_ver1, const kafka_api_version_t& api_ver2) { return api_ver1.api_key < api_ver2.api_key; } int KafkaResponse::parse_apiversions(void **buf, size_t *size) { int16_t error; int32_t api_cnt; int32_t throttle_time; CHECK_RET(parse_i16(buf, size, &error)); CHECK_RET(parse_i32(buf, size, &api_cnt)); if (api_cnt < 0) { errno = EBADMSG; return -1; } void *p = malloc(api_cnt * sizeof(kafka_api_version_t)); if (!p) return -1; this->api->api = (kafka_api_version_t *)p; this->api->elements = api_cnt; for (int32_t i = 0; i < api_cnt; ++i) { CHECK_RET(parse_i16(buf, size, &this->api->api[i].api_key)); CHECK_RET(parse_i16(buf, size, &this->api->api[i].min_ver)); CHECK_RET(parse_i16(buf, size, &this->api->api[i].max_ver)); } if (this->api_version >= 1) CHECK_RET(parse_i32(buf, size, &throttle_time)); std::sort(this->api->api, this->api->api + api_cnt, kafka_api_version_cmp); this->api->features = kafka_get_features(this->api->api, api_cnt); return 0; } int KafkaResponse::parse_saslhandshake(void **buf, 
size_t *size) { std::string mechanism; int16_t error = 0; int32_t cnt, i; CHECK_RET(parse_i16(buf, size, &error)); if (error != 0) { this->broker.get_raw_ptr()->error = error; return 1; } CHECK_RET(parse_i32(buf, size, &cnt)); for (i = 0; i < cnt; i++) { CHECK_RET(parse_string(buf, size, mechanism)); if (strcasecmp(mechanism.c_str(), this->config.get_sasl_mech()) == 0) break; } if (i == cnt) { this->broker.get_raw_ptr()->error = KAFKA_SASL_AUTHENTICATION_FAILED; return 1; } for (i++; i < cnt; i++) CHECK_RET(parse_string(buf, size, mechanism)); errno = 0; if (!this->config.new_client(this->sasl)) { if (errno) return -1; this->broker.get_raw_ptr()->error = KAFKA_SASL_AUTHENTICATION_FAILED; return 1; } return 0; } int KafkaResponse::parse_saslauthenticate(void **buf, size_t *size) { std::string error_message; std::string auth_bytes; int16_t error = 0; CHECK_RET(parse_i16(buf, size, &error)); CHECK_RET(parse_string(buf, size, error_message)); CHECK_RET(parse_bytes(buf, size, auth_bytes)); if (error != 0) { this->broker.get_raw_ptr()->error = error; return 1; } errno = 0; if (this->config.get_raw_ptr()->recv(auth_bytes.c_str(), auth_bytes.size(), this->config.get_raw_ptr(), this->sasl) != 0) { if (errno) return -1; this->broker.get_raw_ptr()->error = KAFKA_SASL_AUTHENTICATION_FAILED; return 1; } return 0; } int KafkaResponse::handle_sasl_continue() { struct iovec iovecs[64]; int ret; int cnt = this->encode(iovecs, 64); if ((unsigned int)cnt > 64) { if (cnt > 64) errno = EOVERFLOW; return -1; } for (int i = 0; i < cnt; i++) { ret = this->feedback(iovecs[i].iov_base, iovecs[i].iov_len); if (ret != (int)iovecs[i].iov_len) { if (ret >= 0) errno = ENOBUFS; return -1; } } return 0; } int KafkaResponse::append(const void *buf, size_t *size) { int ret = KafkaMessage::append(buf, size); if (ret <= 0) return ret; ret = this->parse_response(); if (ret != 0) return ret; if (this->api_type == Kafka_SaslHandshake) { this->api_type = Kafka_SaslAuthenticate; this->clear_buf(); return 
this->handle_sasl_continue(); } else if (this->api_type == Kafka_SaslAuthenticate) { if (strncasecmp(this->config.get_sasl_mech(), "SCRAM", 5) == 0) { this->clear_buf(); if (this->sasl->scram.state != KAFKA_SASL_SCRAM_STATE_CLIENT_FINISHED) return this->handle_sasl_continue(); else this->sasl->status = 1; } } return 1; } } workflow-0.11.8/src/protocol/KafkaMessage.h000066400000000000000000000152151476003635400206120ustar00rootroot00000000000000/* Copyright (c) 2020 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Wang Zhulei(wangzhulei@sogou-inc.com) */ #ifndef _KAFKAMESSAGE_H_ #define _KAFKAMESSAGE_H_ #include #include #include #include #include #include #include "kafka_parser.h" #include "ProtocolMessage.h" #include "KafkaDataTypes.h" namespace protocol { class KafkaMessage : public ProtocolMessage { public: KafkaMessage(); virtual ~KafkaMessage(); protected: virtual int encode(struct iovec vectors[], int max); virtual int append(const void *buf, size_t *size); private: int encode_head(); public: KafkaMessage(KafkaMessage&& msg); KafkaMessage& operator= (KafkaMessage&& msg); public: int encode_message(int api_type, struct iovec vectors[], int max); void set_api_type(int api_type) { this->api_type = api_type; } int get_api_type() const { return this->api_type; } void set_api_version(int ver) { this->api_version = ver; } int get_api_version() const { return this->api_version; } void set_correlation_id(int id) { this->correlation_id = id; } int get_correlation_id() const { return 
this->correlation_id; } void set_config(const KafkaConfig& conf) { this->config = conf; } const KafkaConfig *get_config() const { return &this->config; } void set_cgroup(const KafkaCgroup& cgroup) { this->cgroup = cgroup; } KafkaCgroup *get_cgroup() { return &this->cgroup; } void set_broker(const KafkaBroker& broker) { this->broker = broker; } KafkaBroker *get_broker() { return &this->broker; } void set_meta_list(const KafkaMetaList& meta_list) { this->meta_list = meta_list; } KafkaMetaList *get_meta_list() { return &this->meta_list; } void set_toppar_list(const KafkaTopparList& toppar_list) { this->toppar_list = toppar_list; } KafkaTopparList *get_toppar_list() { return &this->toppar_list; } void set_broker_list(const KafkaBrokerList& broker_list) { this->broker_list = broker_list; } KafkaBrokerList *get_broker_list() { return &this->broker_list; } void set_sasl(kafka_sasl_t *sasl) { this->sasl = sasl; } void set_api(kafka_api_t *api) { this->api = api; } void duplicate(const KafkaMessage& msg) { this->config = msg.config; this->cgroup = msg.cgroup; this->broker = msg.broker; this->meta_list = msg.meta_list; this->broker_list = msg.broker_list; this->toppar_list = msg.toppar_list; this->sasl = msg.sasl; this->api = msg.api; } void clear_buf() { this->msgbuf.clear(); this->headbuf.clear(); kafka_parser_deinit(this->parser); kafka_parser_init(this->parser); this->cur_size = 0; this->serialized = KafkaBuffer(); this->uncompressed = KafkaBuffer(); } protected: static int parse_message_set(void **buf, size_t *size, bool check_crcs, int msg_vers, struct list_head *record_list, KafkaBuffer *uncompressed, KafkaToppar *toppar); static int parse_message_record(void **buf, size_t *size, kafka_record_t *kafka_record); static int parse_record_batch(void **buf, size_t *size, bool check_crcs, struct list_head *record_list, KafkaBuffer *uncompressed, KafkaToppar *toppar); static int parse_records(void **buf, size_t *size, bool check_crcs, KafkaBuffer *uncompressed, KafkaToppar 
*toppar); static std::string get_member_assignment(kafka_member_t *member); static KafkaToppar *find_toppar_by_name(const std::string& topic, int partition, struct list_head *toppar_list); static KafkaToppar *find_toppar_by_name(const std::string& topic, int partition, KafkaTopparList *toppar_list); static int kafka_parse_member_assignment(const char *bbuf, size_t n, KafkaCgroup *cgroup); protected: kafka_parser_t *parser; using encode_func = std::function; std::map encode_func_map; using parse_func = std::function; std::map parse_func_map; class EncodeStream *stream; std::string msgbuf; std::string headbuf; KafkaConfig config; KafkaCgroup cgroup; KafkaBroker broker; KafkaMetaList meta_list; KafkaBrokerList broker_list; KafkaTopparList toppar_list; KafkaBuffer serialized; KafkaBuffer uncompressed; int api_type; int api_version; int correlation_id; int message_version; void *compress_env; size_t cur_size; kafka_sasl_t *sasl; kafka_api_t *api; }; class KafkaRequest : public KafkaMessage { public: KafkaRequest(); private: int encode_produce(struct iovec vectors[], int max); int encode_fetch(struct iovec vectors[], int max); int encode_metadata(struct iovec vectors[], int max); int encode_findcoordinator(struct iovec vectors[], int max); int encode_listoffset(struct iovec vectors[], int max); int encode_joingroup(struct iovec vectors[], int max); int encode_syncgroup(struct iovec vectors[], int max); int encode_leavegroup(struct iovec vectors[], int max); int encode_heartbeat(struct iovec vectors[], int max); int encode_offsetcommit(struct iovec vectors[], int max); int encode_offsetfetch(struct iovec vectors[], int max); int encode_apiversions(struct iovec vectors[], int max); int encode_saslhandshake(struct iovec vectors[], int max); int encode_saslauthenticate(struct iovec vectors[], int max); }; class KafkaResponse : public KafkaRequest { public: KafkaResponse(); int parse_response(); protected: virtual int append(const void *buf, size_t *size); private: int 
parse_produce(void **buf, size_t *size); int parse_fetch(void **buf, size_t *size); int parse_metadata(void **buf, size_t *size); int parse_findcoordinator(void **buf, size_t *size); int parse_joingroup(void **buf, size_t *size); int parse_syncgroup(void **buf, size_t *size); int parse_leavegroup(void **buf, size_t *size); int parse_listoffset(void **buf, size_t *size); int parse_offsetcommit(void **buf, size_t *size); int parse_offsetfetch(void **buf, size_t *size); int parse_heartbeat(void **buf, size_t *size); int parse_apiversions(void **buf, size_t *size); int parse_saslhandshake(void **buf, size_t *size); int parse_saslauthenticate(void **buf, size_t *size); int handle_sasl_continue(); }; } #endif workflow-0.11.8/src/protocol/KafkaResult.cc000066400000000000000000000047031476003635400206420ustar00rootroot00000000000000/* Copyright (c) 2020 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Authors: Wang Zhulei (wangzhulei@sogou-inc.com) */ #include "KafkaResult.h" namespace protocol { enum { KAFKA_STATUS_GET_RESULT, KAFKA_STATUS_END, }; KafkaResult::KafkaResult() { this->resp_vec = NULL; this->resp_num = 0; } KafkaResult& KafkaResult::operator= (KafkaResult&& move) { if (this != &move) { delete []this->resp_vec; this->resp_vec = move.resp_vec; move.resp_vec = NULL; this->resp_num = move.resp_num; move.resp_num = 0; } return *this; } KafkaResult::KafkaResult(KafkaResult&& move) { this->resp_vec = move.resp_vec; move.resp_vec = NULL; this->resp_num = move.resp_num; move.resp_num = 0; } void KafkaResult::create(size_t n) { delete []this->resp_vec; this->resp_vec = new KafkaResponse[n]; this->resp_num = n; } void KafkaResult::set_resp(KafkaResponse&& resp, size_t i) { assert(i < this->resp_num); this->resp_vec[i] = std::move(resp); } void KafkaResult::fetch_toppars(std::vector& toppars) { toppars.clear(); KafkaToppar *toppar = NULL; for (size_t i = 0; i < this->resp_num; ++i) { this->resp_vec[i].get_toppar_list()->rewind(); while ((toppar = this->resp_vec[i].get_toppar_list()->get_next()) != NULL) toppars.push_back(toppar); } } void KafkaResult::fetch_records(std::vector>& records) { records.clear(); KafkaToppar *toppar = NULL; KafkaRecord *record = NULL; for (size_t i = 0; i < this->resp_num; ++i) { if (this->resp_vec[i].get_api_type() != Kafka_Produce && this->resp_vec[i].get_api_type() != Kafka_Fetch) continue; this->resp_vec[i].get_toppar_list()->rewind(); while ((toppar = this->resp_vec[i].get_toppar_list()->get_next()) != NULL) { std::vector tmp; toppar->record_rewind(); while ((record = toppar->get_record_next()) != NULL) tmp.push_back(record); if (!tmp.empty()) records.emplace_back(std::move(tmp)); } } } } workflow-0.11.8/src/protocol/KafkaResult.h000066400000000000000000000024611476003635400205030ustar00rootroot00000000000000/* Copyright (c) 2020 Sogou, Inc. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Authors: Wang Zhulei (wangzhulei@sogou-inc.com) */ #ifndef _KAFKARESULT_H_ #define _KAFKARESULT_H_ #include #include #include #include "KafkaMessage.h" #include "KafkaDataTypes.h" namespace protocol { class KafkaResult { public: // for offsetcommit void fetch_toppars(std::vector& toppars); // for produce, fetch void fetch_records(std::vector>& records); public: void create(size_t n); void set_resp(KafkaResponse&& resp, size_t i); public: KafkaResult(); virtual ~KafkaResult() { delete []this->resp_vec; } KafkaResult& operator= (KafkaResult&& move); KafkaResult(KafkaResult&& move); private: KafkaResponse *resp_vec; size_t resp_num; }; } #endif workflow-0.11.8/src/protocol/MySQLMessage.cc000066400000000000000000000353431476003635400207040ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Authors: Xie Han (xiehan@sogou-inc.com) Wu Jiaxu (wujiaxu@sogou-inc.com) */ #include #include #include #include #include #include #include #include #include #include "SSLWrapper.h" #include "mysql_byteorder.h" #include "mysql_types.h" #include "MySQLResult.h" #include "MySQLMessage.h" namespace protocol { #define MYSQL_PAYLOAD_MAX ((1 << 24) - 1) #define MYSQL_NATIVE_PASSWORD "mysql_native_password" #define CACHING_SHA2_PASSWORD "caching_sha2_password" #define MYSQL_CLEAR_PASSWORD "mysql_clear_password" MySQLMessage::~MySQLMessage() { if (parser_) { mysql_parser_deinit(parser_); mysql_stream_deinit(stream_); delete parser_; delete stream_; } } MySQLMessage::MySQLMessage(MySQLMessage&& move) : ProtocolMessage(std::move(move)) { parser_ = move.parser_; stream_ = move.stream_; seqid_ = move.seqid_; cur_size_ = move.cur_size_; move.parser_ = NULL; move.stream_ = NULL; move.seqid_ = 0; move.cur_size_ = 0; } MySQLMessage& MySQLMessage::operator= (MySQLMessage&& move) { if (this != &move) { *(ProtocolMessage *)this = std::move(move); if (parser_) { mysql_parser_deinit(parser_); mysql_stream_deinit(stream_); delete parser_; delete stream_; } parser_ = move.parser_; stream_ = move.stream_; seqid_ = move.seqid_; cur_size_ = move.cur_size_; move.parser_ = NULL; move.stream_ = NULL; move.seqid_ = 0; move.cur_size_ = 0; } return *this; } int MySQLMessage::append(const void *buf, size_t *size) { const void *stream_buf; size_t stream_len; size_t nleft = *size; size_t n; int ret; cur_size_ += *size; if (cur_size_ > this->size_limit) { errno = EMSGSIZE; return -1; } while (nleft > 0) { n = nleft; ret = mysql_stream_write(buf, &n, stream_); if (ret > 0) { seqid_ = mysql_stream_get_seq(stream_); mysql_stream_get_buf(&stream_buf, &stream_len, stream_); ret = decode_packet((const unsigned char *)stream_buf, stream_len); if (ret == -2) errno = EBADMSG; } if (ret < 0) return -1; nleft -= n; buf = (const char *)buf + n; } return ret; } int MySQLMessage::encode(struct iovec vectors[], int 
max) { const unsigned char *p = (unsigned char *)buf_.c_str(); size_t nleft = buf_.size(); uint8_t seqid_start = seqid_; uint8_t seqid = seqid_; unsigned char *head; uint32_t length; int i = 0; do { length = (nleft >= MYSQL_PAYLOAD_MAX ? MYSQL_PAYLOAD_MAX : (uint32_t)nleft); head = heads_[seqid]; int3store(head, length); head[3] = seqid++; vectors[i].iov_base = head; vectors[i].iov_len = 4; i++; vectors[i].iov_base = const_cast(p); vectors[i].iov_len = length; i++; if (i > max)//overflow break; if (nleft < MYSQL_PAYLOAD_MAX) return i; nleft -= MYSQL_PAYLOAD_MAX; p += length; } while (seqid != seqid_start); errno = EOVERFLOW; return -1; } void MySQLRequest::set_query(const char *query, size_t length) { set_command(MYSQL_COM_QUERY); buf_.resize(length + 1); char *buffer = const_cast(buf_.c_str()); buffer[0] = MYSQL_COM_QUERY; if (length > 0) memcpy(buffer + 1, query, length); } std::string MySQLRequest::get_query() const { size_t len = buf_.size(); if (len <= 1 || buf_[0] != MYSQL_COM_QUERY) return ""; return std::string(buf_.c_str() + 1); } #define MYSQL_CAPFLAG_CLIENT_SSL 0x00000800 #define MYSQL_CAPFLAG_CLIENT_PROTOCOL_41 0x00000200 #define MYSQL_CAPFLAG_CLIENT_SECURE_CONNECTION 0x00008000 #define MYSQL_CAPFLAG_CLIENT_CONNECT_WITH_DB 0x00000008 #define MYSQL_CAPFLAG_CLIENT_MULTI_STATEMENTS 0x00010000 #define MYSQL_CAPFLAG_CLIENT_MULTI_RESULTS 0x00020000 #define MYSQL_CAPFLAG_CLIENT_PS_MULTI_RESULTS 0x00040000 #define MYSQL_CAPFLAG_CLIENT_PLUGIN_AUTH 0x00080000 #define MYSQL_CAPFLAG_CLIENT_LOCAL_FILES 0x00000080 int MySQLHandshakeResponse::encode(struct iovec vectors[], int max) { const char empty[10] = {0}; uint16_t cap_flags_lower = capability_flags_ & 0xffffffff; uint16_t cap_flags_upper = capability_flags_ >> 16; buf_.clear(); buf_.append((const char *)&protocol_version_, 1); buf_.append(server_version_.c_str(), server_version_.size() + 1); buf_.append((const char *)&connection_id_, 4); buf_.append((const char *)auth_plugin_data_, 8); buf_.append(empty, 1); 
buf_.append((const char *)&cap_flags_lower, 2); buf_.append((const char *)&character_set_, 1); buf_.append((const char *)&status_flags_, 2); buf_.append((const char *)&cap_flags_upper, 2); buf_.push_back(21); buf_.append(empty, 10); buf_.append((const char *)auth_plugin_data_ + 8, 12); buf_.push_back(0); if (capability_flags_ & MYSQL_CAPFLAG_CLIENT_PLUGIN_AUTH) buf_.append(MYSQL_NATIVE_PASSWORD, strlen(MYSQL_NATIVE_PASSWORD) + 1); return MySQLMessage::encode(vectors, max); } int MySQLHandshakeResponse::decode_packet(const unsigned char *buf, size_t buflen) { const unsigned char *end = buf + buflen; const unsigned char *pos; uint16_t cap_flags_lower; uint16_t cap_flags_upper; if (buflen == 0) return -2; protocol_version_ = *buf; if (protocol_version_ == 255) { if (buflen >= 4) { const_cast(buf)[3] = '#'; if (mysql_parser_parse(buf, buflen, parser_) == 1) { disallowed_ = true; return 1; } } errno = EBADMSG; return -1; } pos = ++buf; while (pos < end && *pos) pos++; if (pos >= end || end - pos < 45) return -2; server_version_.assign((const char *)buf, pos - buf); buf = pos + 1; connection_id_ = uint4korr(buf); buf += 4; memcpy(auth_plugin_data_, buf, 8); buf += 9; cap_flags_lower = uint2korr(buf); buf += 2; character_set_ = *buf++; status_flags_ = uint2korr(buf); buf += 2; cap_flags_upper = uint2korr(buf); buf += 2; capability_flags_ = (cap_flags_upper << 16U) + cap_flags_lower; auth_plugin_data_len_ = *buf++; // 10 bytes reserved. All 0s. 
buf += 10; // auth_plugin_data always 20 bytes if (auth_plugin_data_len_ > 21) return -2; memcpy(auth_plugin_data_ + 8, buf, 12); buf += 13; if (capability_flags_ & MYSQL_CAPFLAG_CLIENT_PLUGIN_AUTH) { if (buf == end || *(end - 1) != '\0') return -2; auth_plugin_name_.assign((const char *)buf, end - 1 - buf); } return 1; } static std::string __native_password_encrypt(const std::string& password, unsigned char seed[20]) { unsigned char buf1[20]; unsigned char buf2[40]; int i; // SHA1( password ) ^ SHA1( seed + SHA1( SHA1( password ) ) ) SHA1((unsigned char *)password.c_str(), password.size(), buf1); SHA1(buf1, 20, buf2 + 20); memcpy(buf2, seed, 20); SHA1(buf2, 40, buf2); for (i = 0; i < 20; i++) buf1[i] ^= buf2[i]; return std::string((const char *)buf1, 20); } static std::string __caching_sha2_password_encrypt(const std::string& password, unsigned char seed[20]) { unsigned char buf1[32]; unsigned char buf2[52]; int i; // SHA256( password ) ^ SHA256( SHA256( SHA256( password ) ) + seed) SHA256((unsigned char *)password.c_str(), password.size(), buf1); SHA256(buf1, 32, buf2); memcpy(buf2 + 32, seed, 20); SHA256(buf2, 52, buf2); for (i = 0; i < 32; i++) buf1[i] ^= buf2[i]; return std::string((const char *)buf1, 32); } int MySQLSSLRequest::encode(struct iovec vectors[], int max) { unsigned char header[32] = {0}; unsigned char *pos = header; int ret; int4store(pos, MYSQL_CAPFLAG_CLIENT_SSL | MYSQL_CAPFLAG_CLIENT_PROTOCOL_41 | MYSQL_CAPFLAG_CLIENT_SECURE_CONNECTION | MYSQL_CAPFLAG_CLIENT_CONNECT_WITH_DB | MYSQL_CAPFLAG_CLIENT_MULTI_RESULTS| MYSQL_CAPFLAG_CLIENT_LOCAL_FILES | MYSQL_CAPFLAG_CLIENT_MULTI_STATEMENTS | MYSQL_CAPFLAG_CLIENT_PS_MULTI_RESULTS | MYSQL_CAPFLAG_CLIENT_PLUGIN_AUTH); pos += 4; int4store(pos, 0); pos += 4; *pos = (uint8_t)character_set_; buf_.clear(); buf_.append((char *)header, 32); ret = MySQLMessage::encode(vectors, max); if (ret >= 0) { max -= ret; if (max >= 8) /* Indeed SSL handshaker needs only 1 vector. 
*/ { max = ssl_handshaker_.encode(vectors + ret, max); if (max >= 0) return max + ret; } else errno = EOVERFLOW; } return -1; } int MySQLAuthRequest::encode(struct iovec vectors[], int max) { unsigned char header[32] = {0}; unsigned char *pos = header; std::string str; int4store(pos, MYSQL_CAPFLAG_CLIENT_PROTOCOL_41 | MYSQL_CAPFLAG_CLIENT_SECURE_CONNECTION | MYSQL_CAPFLAG_CLIENT_CONNECT_WITH_DB | MYSQL_CAPFLAG_CLIENT_MULTI_RESULTS| MYSQL_CAPFLAG_CLIENT_LOCAL_FILES | MYSQL_CAPFLAG_CLIENT_MULTI_STATEMENTS | MYSQL_CAPFLAG_CLIENT_PS_MULTI_RESULTS | MYSQL_CAPFLAG_CLIENT_PLUGIN_AUTH); pos += 4; int4store(pos, 0); pos += 4; *pos = (uint8_t)character_set_; if (password_.empty()) str.push_back(0); else if (auth_plugin_name_ == CACHING_SHA2_PASSWORD) { str.push_back(32); str += __caching_sha2_password_encrypt(password_, seed_); } else { str.push_back(20); str += __native_password_encrypt(password_, seed_); } buf_.clear(); buf_.append((char *)header, 32); buf_.append(username_.c_str(), username_.size() + 1); buf_.append(str); buf_.append(db_.c_str(), db_.size() + 1); if (auth_plugin_name_.size() != 0) buf_.append(auth_plugin_name_.c_str(), auth_plugin_name_.size() + 1); return MySQLMessage::encode(vectors, max); } int MySQLAuthRequest::decode_packet(const unsigned char *buf, size_t buflen) { const unsigned char *end = buf + buflen; const unsigned char *pos; if (buflen < 32) return -2; uint32_t flags = uint4korr(buf); if (!(flags & MYSQL_CAPFLAG_CLIENT_PROTOCOL_41)) return -2; buf += 8; character_set_ = *buf++; buf += 23; pos = buf; while (pos < end && *pos) pos++; if (pos >= end) return -2; username_.assign((const char *)buf, pos - buf); buf = pos + 1; return 1; } int MySQLAuthResponse::decode_packet(const unsigned char *buf, size_t buflen) { const unsigned char *end = buf + buflen; const unsigned char *pos; const unsigned char *str; unsigned long long len; if (end == buf) return -2; switch (*buf) { case 0x00: case 0xff: return MySQLResponse::decode_packet(buf, buflen); case 
0xfe: pos = ++buf; while (pos < end && *pos) pos++; if (pos >= end) return -2; auth_plugin_name_.assign((const char *)buf, pos - buf); buf = pos + 1; if (buf == end || *(end - 1) != '\0') return -2; if (end - 1 - buf != 20) return -2; memcpy(seed_, buf, 20); return 1; default: pos = buf; if (decode_string(&str, &len, &pos, end) > 0 && len == 1) { if (*str == 0x03) { if (end > pos) return MySQLResponse::decode_packet(pos, end - pos); else return 0; } else if (*str == 0x04) { continue_ = true; return 1; } } return -2; } } int MySQLAuthSwitchRequest::encode(struct iovec vectors[], int max) { if (password_.empty()) { buf_ = "\0"; } else if (auth_plugin_name_ == MYSQL_NATIVE_PASSWORD) { buf_ = __native_password_encrypt(password_, seed_); } else if (auth_plugin_name_ == CACHING_SHA2_PASSWORD) { buf_ = __caching_sha2_password_encrypt(password_, seed_); } else if (auth_plugin_name_ == MYSQL_CLEAR_PASSWORD) { buf_ = password_; buf_.push_back('\0'); } else { errno = EINVAL; return -1; } return MySQLMessage::encode(vectors, max); } int MySQLPublicKeyResponse::decode_packet(const unsigned char *buf, size_t buflen) { if (buflen == 0 || *buf != 0x01) return -2; if (buflen == 1) return 0; public_key_.assign((const char *)buf + 1, buflen - 1); return 1; } int MySQLPublicKeyResponse::encode(struct iovec vectors[], int max) { buf_.clear(); buf_.push_back(0x01); buf_ += public_key_; return MySQLMessage::encode(vectors, max); } int MySQLRSAAuthRequest::rsa_encrypt(void *ctx) { EVP_PKEY_CTX *pkey_ctx = (EVP_PKEY_CTX *)ctx; unsigned char out[256]; size_t outlen = 256; std::string pass; unsigned char *p; size_t i; if (EVP_PKEY_encrypt_init(pkey_ctx) > 0 && EVP_PKEY_CTX_set_rsa_padding(pkey_ctx, RSA_PKCS1_OAEP_PADDING) > 0) { pass.reserve(password_.size() + 1); p = (unsigned char *)pass.c_str(); for (i = 0; i <= password_.size(); i++) p[i] = (unsigned char)password_[i] ^ seed_[i % 20]; if (EVP_PKEY_encrypt(pkey_ctx, out, &outlen, p, i) > 0) { buf_.assign((char *)out, 256); return 0; } } 
return -1; } int MySQLRSAAuthRequest::encode(struct iovec vectors[], int max) { BIO *bio; EVP_PKEY *pkey; EVP_PKEY_CTX *pkey_ctx; int ret = -1; bio = BIO_new_mem_buf((void *)public_key_.c_str(), public_key_.size()); if (bio) { pkey = PEM_read_bio_PUBKEY(bio, NULL, NULL, NULL); if (pkey) { pkey_ctx = EVP_PKEY_CTX_new(pkey, NULL); if (pkey_ctx) { ret = rsa_encrypt(pkey_ctx); EVP_PKEY_CTX_free(pkey_ctx); } EVP_PKEY_free(pkey); } BIO_free(bio); } if (ret < 0) return ret; return MySQLMessage::encode(vectors, max); } void MySQLResponse::set_ok_packet() { uint16_t zero16 = 0; buf_.clear(); buf_.push_back(0x00); buf_.append((const char *)&zero16, 2); buf_.append((const char *)&zero16, 2); buf_.append((const char *)&zero16, 2); } int MySQLResponse::decode_packet(const unsigned char *buf, size_t buflen) { return mysql_parser_parse(buf, buflen, parser_); } unsigned long long MySQLResponse::get_affected_rows() const { unsigned long long affected_rows = 0; MySQLResultCursor cursor(this); do { affected_rows += cursor.get_affected_rows(); } while (cursor.next_result_set()); return affected_rows; } // return array api unsigned long long MySQLResponse::get_last_insert_id() const { unsigned long long insert_id = 0; MySQLResultCursor cursor(this); do { if (cursor.get_insert_id()) insert_id = cursor.get_insert_id(); } while (cursor.next_result_set()); return insert_id; } int MySQLResponse::get_warnings() const { int warning_count = 0; MySQLResultCursor cursor(this); do { warning_count += cursor.get_warnings(); } while (cursor.next_result_set()); return warning_count; } std::string MySQLResponse::get_info() const { std::string info; MySQLResultCursor cursor(this); do { if (info.length() > 0) info += " "; info += cursor.get_info(); } while (cursor.next_result_set()); return info; } bool MySQLResponse::is_ok_packet() const { return parser_->packet_type == MYSQL_PACKET_OK; } bool MySQLResponse::is_error_packet() const { return parser_->packet_type == MYSQL_PACKET_ERROR; } } 
workflow-0.11.8/src/protocol/MySQLMessage.h000066400000000000000000000053661476003635400205500ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) */ #ifndef _MYSQLMESSAGE_H_ #define _MYSQLMESSAGE_H_ #include #include #include "ProtocolMessage.h" #include "mysql_stream.h" #include "mysql_parser.h" /** * @file MySQLMessage.h * @brief MySQL Protocol Interface */ namespace protocol { class MySQLMessage : public ProtocolMessage { public: mysql_parser_t *get_parser() const; int get_seqid() const; void set_seqid(int seqid); int get_command() const; protected: virtual int append(const void *buf, size_t *size); virtual int encode(struct iovec vectors[], int max); virtual int decode_packet(const unsigned char *buf, size_t buflen) { return 1; } void set_command(int cmd) const; //append mysql_stream_t *stream_; mysql_parser_t *parser_; //encode unsigned char heads_[256][4]; uint8_t seqid_; std::string buf_; size_t cur_size_; public: MySQLMessage(); virtual ~MySQLMessage(); //move constructor MySQLMessage(MySQLMessage&& move); //move operator MySQLMessage& operator= (MySQLMessage&& move); }; class MySQLRequest : public MySQLMessage { public: void set_query(const char *query); void set_query(const std::string& query); void set_query(const char *query, size_t length); std::string get_query() const; bool query_is_unset() const; public: MySQLRequest() = default; //move constructor MySQLRequest(MySQLRequest&& move) = 
default; //move operator MySQLRequest& operator= (MySQLRequest&& move) = default; }; class MySQLResponse : public MySQLMessage { public: bool is_ok_packet() const; bool is_error_packet() const; int get_packet_type() const; unsigned long long get_affected_rows() const; unsigned long long get_last_insert_id() const; int get_warnings() const; int get_error_code() const; std::string get_error_msg() const; std::string get_sql_state() const; std::string get_info() const; void set_ok_packet(); public: MySQLResponse() = default; //move constructor MySQLResponse(MySQLResponse&& move) = default; //move operator MySQLResponse& operator= (MySQLResponse&& move) = default; protected: virtual int decode_packet(const unsigned char *buf, size_t buflen); }; } //impl. not for user #include "MySQLMessage.inl" #endif workflow-0.11.8/src/protocol/MySQLMessage.inl000066400000000000000000000216141476003635400210750ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Authors: Xie Han (xiehan@sogou-inc.com) Wu Jiaxu (wujiaxu@sogou-inc.com) */ #include #include #include #include #include #include "SSLWrapper.h" namespace protocol { class MySQLHandshakeRequest : public MySQLRequest { private: virtual int encode(struct iovec vectors[], int max) { return 0; } }; class MySQLHandshakeResponse : public MySQLResponse { public: std::string get_server_version() const { return server_version_; } std::string get_auth_plugin_name() const { return auth_plugin_name_; } void get_seed(unsigned char seed[20]) const { memcpy(seed, auth_plugin_data_, 20); } virtual int encode(struct iovec vectors[], int max); void server_set(uint8_t protocol_version, const std::string server_version, uint32_t connection_id, const unsigned char seed[20], uint32_t capability_flags, uint8_t character_set, uint16_t status_flags) { protocol_version_ = protocol_version; server_version_ = server_version; connection_id_ = connection_id; memcpy(auth_plugin_data_, seed, 20); capability_flags_ = capability_flags; character_set_ = character_set; status_flags_ = status_flags; } bool host_disallowed() const { return disallowed_; } uint32_t get_capability_flags() const { return capability_flags_; } uint16_t get_status_flags() const { return status_flags_; } private: virtual int decode_packet(const unsigned char *buf, size_t buflen); std::string server_version_; std::string auth_plugin_name_; unsigned char auth_plugin_data_[20]; uint32_t connection_id_; uint32_t capability_flags_; uint16_t status_flags_; uint8_t character_set_; uint8_t auth_plugin_data_len_; uint8_t protocol_version_; bool disallowed_; public: MySQLHandshakeResponse() : disallowed_(false) { } //move constructor MySQLHandshakeResponse(MySQLHandshakeResponse&& move) = default; //move operator MySQLHandshakeResponse& operator= (MySQLHandshakeResponse&& move) = default; }; class MySQLSSLRequest : public MySQLRequest { private: virtual int encode(struct iovec vectors[], int max); /* Do not support server side with SSL 
currently. */ virtual int decode_packet(const unsigned char *buf, size_t buflen) { return -2; } private: int character_set_; SSLHandshaker ssl_handshaker_; public: MySQLSSLRequest(int character_set, SSL *ssl) : ssl_handshaker_(ssl) { character_set_ = character_set; } MySQLSSLRequest(MySQLSSLRequest&& move) = default; MySQLSSLRequest& operator= (MySQLSSLRequest&& move) = default; }; class MySQLAuthRequest : public MySQLRequest { public: void set_auth(const std::string username, const std::string password, const std::string db, int character_set) { username_ = std::move(username); password_ = std::move(password); db_ = std::move(db); character_set_ = character_set; } void set_auth_plugin_name(std::string name) { auth_plugin_name_ = std::move(name); } void set_seed(const unsigned char seed[20]) { memcpy(seed_, seed, 20); } private: virtual int encode(struct iovec vectors[], int max); virtual int decode_packet(const unsigned char *buf, size_t buflen); std::string username_; std::string password_; std::string db_; std::string auth_plugin_name_; unsigned char seed_[20]; int character_set_; public: MySQLAuthRequest() : character_set_(33) { } //move constructor MySQLAuthRequest(MySQLAuthRequest&& move) = default; //move operator MySQLAuthRequest& operator= (MySQLAuthRequest&& move) = default; }; class MySQLAuthResponse : public MySQLResponse { public: std::string get_auth_plugin_name() const { return auth_plugin_name_; } void get_seed(unsigned char seed[20]) const { memcpy(seed, seed_, 20); } bool is_continue() const { return continue_; } private: virtual int decode_packet(const unsigned char *buf, size_t buflen); private: std::string auth_plugin_name_; unsigned char seed_[20]; bool continue_; public: MySQLAuthResponse() : continue_(false) { } //move constructor MySQLAuthResponse(MySQLAuthResponse&& move) = default; //move operator MySQLAuthResponse& operator= (MySQLAuthResponse&& move) = default; }; class MySQLAuthSwitchRequest : public MySQLRequest { public: void 
set_password(std::string password) { password_ = std::move(password); } void set_auth_plugin_name(std::string name) { auth_plugin_name_ = std::move(name); } void set_seed(const unsigned char seed[20]) { memcpy(seed_, seed, 20); } private: virtual int encode(struct iovec vectors[], int max); /* Not implemented. */ virtual int decode_packet(const unsigned char *buf, size_t buflen) { return -2; } std::string password_; std::string auth_plugin_name_; unsigned char seed_[20]; public: MySQLAuthSwitchRequest() { } //move constructor MySQLAuthSwitchRequest(MySQLAuthSwitchRequest&& move) = default; //move operator MySQLAuthSwitchRequest& operator= (MySQLAuthSwitchRequest&& move) = default; }; class MySQLPublicKeyRequest : public MySQLRequest { public: void set_caching_sha2() { byte_ = 0x02; } void set_sha256() { byte_ = 0x01; } private: virtual int encode(struct iovec vectors[], int max) { buf_.assign(&byte_, 1); return MySQLRequest::encode(vectors, max); } /* Not implemented. */ virtual int decode_packet(const unsigned char *buf, size_t buflen) { return -2; } char byte_; public: MySQLPublicKeyRequest() : byte_(0x01) { } //move constructor MySQLPublicKeyRequest(MySQLPublicKeyRequest&& move) = default; //move operator MySQLPublicKeyRequest& operator= (MySQLPublicKeyRequest&& move) = default; }; class MySQLPublicKeyResponse : public MySQLResponse { public: std::string get_public_key() const { return public_key_; } void set_public_key(std::string key) { public_key_ = std::move(key); } private: virtual int encode(struct iovec vectors[], int max); virtual int decode_packet(const unsigned char *buf, size_t buflen); std::string public_key_; public: MySQLPublicKeyResponse() { } //move constructor MySQLPublicKeyResponse(MySQLPublicKeyResponse&& move) = default; //move operator MySQLPublicKeyResponse& operator= (MySQLPublicKeyResponse&& move) = default; }; class MySQLRSAAuthRequest : public MySQLRequest { public: void set_password(std::string password) { password_ = 
std::move(password); } void set_public_key(std::string key) { public_key_ = std::move(key); } void set_seed(const unsigned char seed[20]) { memcpy(seed_, seed, 20); } private: virtual int encode(struct iovec vectors[], int max); /* Not implemented. */ virtual int decode_packet(const unsigned char *buf, size_t buflen) { return -2; } int rsa_encrypt(void *ctx); std::string password_; std::string public_key_; unsigned char seed_[20]; public: MySQLRSAAuthRequest() { } //move constructor MySQLRSAAuthRequest(MySQLRSAAuthRequest&& move) = default; //move operator MySQLRSAAuthRequest& operator= (MySQLRSAAuthRequest&& move) = default; }; ////////// inline mysql_parser_t *MySQLMessage::get_parser() const { return parser_; } inline int MySQLMessage::get_seqid() const { return seqid_; } inline void MySQLMessage::set_seqid(int seqid) { seqid_ = seqid; } inline int MySQLMessage::get_command() const { return parser_->cmd; } inline void MySQLMessage::set_command(int cmd) const { mysql_parser_set_command(cmd, parser_); } inline MySQLMessage::MySQLMessage(): stream_(new mysql_stream_t), parser_(new mysql_parser_t), seqid_(0), cur_size_(0) { mysql_stream_init(stream_); mysql_parser_init(parser_); } inline bool MySQLRequest::query_is_unset() const { return buf_.empty(); } inline void MySQLRequest::set_query(const char *query) { set_query(query, strlen(query)); } inline void MySQLRequest::set_query(const std::string& query) { set_query(query.c_str(), query.size()); } inline int MySQLResponse::get_packet_type() const { return parser_->packet_type; } inline int MySQLResponse::get_error_code() const { return is_error_packet() ? 
parser_->error : 0; } inline std::string MySQLResponse::get_error_msg() const { if (is_error_packet()) { const char *s; size_t slen; mysql_parser_get_err_msg(&s, &slen, parser_); if (slen > 0) return std::string(s, slen); } return std::string(); } inline std::string MySQLResponse::get_sql_state() const { if (is_error_packet()) { const char *s; size_t slen; mysql_parser_get_net_state(&s, &slen, parser_); if (slen > 0) return std::string(s, slen); } return std::string(); } } workflow-0.11.8/src/protocol/MySQLResult.cc000066400000000000000000000175321476003635400205760ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Authors: Li Yingxin (liyingxin@sogou-inc.com) */ #include #include "mysql_types.h" #include "mysql_byteorder.h" #include "MySQLMessage.h" #include "MySQLResult.h" namespace protocol { MySQLField::MySQLField(const void *buf, mysql_field_t *field) { const char *p = (const char *)buf; this->name = p + field->name_offset; this->name_length = field->name_length; this->org_name = p + field->org_name_offset; this->org_name_length = field->org_name_length; this->table = p + field->table_offset; this->table_length = field->table_length; this->org_table = p + field->org_table_offset; this->org_table_length = field->org_table_length; this->db = p + field->db_offset; this->db_length = field->db_length; this->catalog = p + field->catalog_offset; this->catalog_length = field->catalog_length; if (field->def_offset == (size_t)-1 && field->def_length == 0) { this->def = NULL; this->def_length = 0; } else { this->def = p + field->def_offset; this->def_length = field->def_length; } this->flags = field->flags; this->length = field->length; this->decimals = field->decimals; this->charsetnr = field->charsetnr; this->data_type = field->data_type; } MySQLResultCursor::MySQLResultCursor() { this->init(); } void MySQLResultCursor::init() { this->current_field = 0; this->current_row = 0; this->field_count = 0; this->fields = NULL; this->parser = NULL; this->status = MYSQL_STATUS_NOT_INIT; } MySQLResultCursor::MySQLResultCursor(const MySQLResponse *resp) { this->init(resp); } void MySQLResultCursor::reset(MySQLResponse *resp) { this->clear(); this->init(resp); } void MySQLResultCursor::fetch_result_set(const struct __mysql_result_set *result_set) { const char *buf = (const char *)this->parser->buf; this->server_status = result_set->server_status; switch (result_set->type) { case MYSQL_PACKET_GET_RESULT: this->status = MYSQL_STATUS_GET_RESULT; this->field_count = result_set->field_count; this->start = buf + result_set->rows_begin_offset; this->pos = this->start; this->end = buf + 
result_set->rows_end_offset; this->row_count = result_set->row_count; this->fields = new MySQLField *[this->field_count]; for (int i = 0; i < this->field_count; i++) this->fields[i] = new MySQLField(this->parser->buf, result_set->fields[i]); break; case MYSQL_PACKET_OK: this->status = MYSQL_STATUS_OK; this->affected_rows = result_set->affected_rows; this->insert_id = result_set->insert_id; this->warning_count = result_set->warning_count; this->start = buf + result_set->info_offset; this->info_len = result_set->info_len; this->field_count = 0; this->fields = NULL; break; default: this->status = MYSQL_STATUS_ERROR; break; } } void MySQLResultCursor::init(const MySQLResponse *resp) { this->current_field = 0; this->current_row = 0; this->field_count = 0; this->fields = NULL; this->parser = resp->get_parser(); this->status = MYSQL_STATUS_NOT_INIT; if (!list_empty(&this->parser->result_set_list)) { struct __mysql_result_set *result_set; mysql_result_set_cursor_init(&this->cursor, this->parser); mysql_result_set_cursor_next(&result_set, &this->cursor); this->fetch_result_set(result_set); } } bool MySQLResultCursor::next_result_set() { if (this->status == MYSQL_STATUS_NOT_INIT || this->status == MYSQL_STATUS_ERROR) { return false; } struct __mysql_result_set *result_set; if (mysql_result_set_cursor_next(&result_set, &this->cursor) == 0) { for (int i = 0; i < this->field_count; i++) delete this->fields[i]; delete []this->fields; this->current_field = 0; this->current_row = 0; this->fetch_result_set(result_set); return true; } else { this->status = MYSQL_STATUS_END; return false; } } bool MySQLResultCursor::fetch_row(std::vector& row_arr) { if (this->status != MYSQL_STATUS_GET_RESULT) return false; unsigned long long len; const unsigned char *data; int data_type; const unsigned char *p = (const unsigned char *)this->pos; const unsigned char *end = (const unsigned char *)this->end; row_arr.clear(); for (int i = 0; i < this->field_count; i++) { data_type = 
this->fields[i]->get_data_type(); if (*p == MYSQL_PACKET_HEADER_NULL) { data = NULL; len = 0; p++; data_type = MYSQL_TYPE_NULL; } else if (decode_string(&data, &len, &p, end) == 0) { this->status = MYSQL_STATUS_ERROR; return false; } row_arr.emplace_back(data, len, data_type); } if (++this->current_row == this->row_count) this->status = MYSQL_STATUS_END; this->pos = p; return true; } bool MySQLResultCursor::fetch_row(std::map& row_map) { return this->fetch_row>(row_map); } bool MySQLResultCursor::fetch_row(std::unordered_map& row_map) { return this->fetch_row>(row_map); } bool MySQLResultCursor::fetch_row_nocopy(const void **data, size_t *len, int *data_type) { if (this->status != MYSQL_STATUS_GET_RESULT) return false; unsigned long long cell_len; const unsigned char *cell_data; const unsigned char *p = (const unsigned char *)this->pos; const unsigned char *end = (const unsigned char *)this->end; for (int i = 0; i < this->field_count; i++) { if (*p == MYSQL_PACKET_HEADER_NULL) { cell_data = NULL; cell_len = 0; p++; } else if (decode_string(&cell_data, &cell_len, &p, end) == 0) { this->status = MYSQL_STATUS_ERROR; return false; } data[i] = cell_data; len[i] = cell_len; data_type[i] = this->fields[i]->get_data_type(); } this->pos = p; if (++this->current_row == this->row_count) this->status = MYSQL_STATUS_END; return true; } bool MySQLResultCursor::fetch_all(std::vector>& rows) { if (this->status != MYSQL_STATUS_GET_RESULT) return false; unsigned long long len; const unsigned char *data; int data_type; const unsigned char *p = (const unsigned char *)this->pos; const unsigned char *end = (const unsigned char *)this->end; rows.clear(); for (int i = this->current_row; i < this->row_count; i++) { std::vector tmp; for (int j = 0; j < this->field_count; j++) { data_type = this->fields[j]->get_data_type(); if (*p == MYSQL_PACKET_HEADER_NULL) { data = NULL; len = 0; p++; data_type = MYSQL_TYPE_NULL; } else if (decode_string(&data, &len, &p, end) == 0) { this->status = 
MYSQL_STATUS_ERROR; return false; } tmp.emplace_back(data, len, data_type); } rows.emplace_back(std::move(tmp)); } this->current_row = this->row_count; this->status = MYSQL_STATUS_END; this->pos = p; return true; } void MySQLResultCursor::first_result_set() { if (this->status == MYSQL_STATUS_NOT_INIT || this->status == MYSQL_STATUS_ERROR) { return; } mysql_result_set_cursor_rewind(&this->cursor); struct __mysql_result_set *result_set; if (mysql_result_set_cursor_next(&result_set, &this->cursor) == 0) { for (int i = 0; i < this->field_count; i++) delete this->fields[i]; delete []this->fields; this->current_field = 0; this->current_row = 0; this->fetch_result_set(result_set); } } void MySQLResultCursor::rewind() { if (this->status != MYSQL_STATUS_GET_RESULT && this->status != MYSQL_STATUS_END) { return; } this->current_field = 0; this->current_row = 0; this->pos = this->start; } } workflow-0.11.8/src/protocol/MySQLResult.h000066400000000000000000000110731476003635400204320ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Authors: Li Yingxin (liyingxin@sogou-inc.com) */ #ifndef _MYSQLRESULT_H_ #define _MYSQLRESULT_H_ #include #include #include #include #include "mysql_types.h" #include "mysql_parser.h" #include "MySQLMessage.h" /** * @file MySQLResult.h * @brief MySQL toolbox for visit result */ namespace protocol { class MySQLCell { public: MySQLCell(); MySQLCell(MySQLCell&& move); MySQLCell& operator=(MySQLCell&& move); MySQLCell(const void *data, size_t len, int data_type); int get_data_type() const; bool is_null() const; bool is_int() const; bool is_string() const; bool is_float() const; bool is_double() const; bool is_ulonglong() const; bool is_date() const; bool is_time() const; bool is_datetime() const; // for copy int as_int() const; std::string as_string() const; std::string as_binary_string() const; float as_float() const; double as_double() const; unsigned long long as_ulonglong() const; std::string as_date() const; std::string as_time() const; std::string as_datetime() const; // for nocopy void get_cell_nocopy(const void **data, size_t *len, int *data_type) const; private: int data_type; void *data; size_t len; }; class MySQLField { public: MySQLField(const void *buf, mysql_field_t *field); std::string get_name() const; std::string get_org_name() const; std::string get_table() const; std::string get_org_table() const; std::string get_db() const; std::string get_catalog() const; std::string get_def() const; int get_charsetnr() const; int get_length() const; int get_flags() const; int get_decimals() const; int get_data_type() const; private: const char *name; /* Name of column */ const char *org_name; /* Original column name, if an alias */ const char *table; /* Table of column if column was a field */ const char *org_table; /* Org table name, if table was an alias */ const char *db; /* Database for table */ const char *catalog; /* Catalog for table */ const char *def; /* Default value (set by mysql_list_fields) */ int length; /* Width of column (create length) */ int 
name_length; int org_name_length; int table_length; int org_table_length; int db_length; int catalog_length; int def_length; int flags; /* Div flags */ int decimals; /* Number of decimals in field */ int charsetnr; /* Character set */ int data_type; /* Type of field. See mysql_types.h for types */ }; class MySQLResultCursor { public: MySQLResultCursor(const MySQLResponse *resp); MySQLResultCursor(MySQLResultCursor&& move); MySQLResultCursor& operator=(MySQLResultCursor&& move); virtual ~MySQLResultCursor(); bool next_result_set(); void first_result_set(); const MySQLField *fetch_field(); const MySQLField *const *fetch_fields() const; bool fetch_row(std::vector& row_arr); bool fetch_row(std::map& row_map); bool fetch_row(std::unordered_map& row_map); bool fetch_row_nocopy(const void **data, size_t *len, int *data_type); bool fetch_all(std::vector>& rows); int get_cursor_status() const; int get_server_status() const; int get_field_count() const; int get_rows_count() const; unsigned long long get_affected_rows() const; unsigned long long get_insert_id() const; int get_warnings() const; std::string get_info() const; void rewind(); public: MySQLResultCursor(); void reset(MySQLResponse *resp); private: void init(const MySQLResponse *resp); void init(); void clear(); void fetch_result_set(const struct __mysql_result_set *result_set); template bool fetch_row(T& row_map); int status; int server_status; const void *start; const void *end; const void *pos; const void **row_data; MySQLField **fields; int row_count; int field_count; int current_row; int current_field; unsigned long long affected_rows; unsigned long long insert_id; int warning_count; int info_len; mysql_result_set_cursor_t cursor; mysql_parser_t *parser; }; } #include "MySQLResult.inl" #endif workflow-0.11.8/src/protocol/MySQLResult.inl000066400000000000000000000236761476003635400210010ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Authors: Li Yingxin (liyingxin@sogou-inc.com) */ #include #include #include #include "mysql_byteorder.h" namespace protocol { inline std::string MySQLField::get_name() const { if (this->data_type == MYSQL_TYPE_NULL) return ""; return std::string(this->name, this->name_length); } inline std::string MySQLField::get_org_name() const { if (this->data_type == MYSQL_TYPE_NULL) return ""; return std::string(this->org_name, this->org_name_length); } inline std::string MySQLField::get_table() const { if (this->data_type == MYSQL_TYPE_NULL) return ""; return std::string(this->table, this->table_length); } inline std::string MySQLField::get_org_table() const { if (this->data_type == MYSQL_TYPE_NULL) return ""; return std::string(this->org_table, this->org_table_length); } inline std::string MySQLField::get_db() const { if (this->data_type == MYSQL_TYPE_NULL) return ""; return std::string(this->db, this->db_length); } inline std::string MySQLField::get_catalog() const { if (this->data_type == MYSQL_TYPE_NULL) return ""; return std::string(this->catalog, this->catalog_length); } inline std::string MySQLField::get_def() const { if (this->data_type == MYSQL_TYPE_NULL) return ""; return std::string(this->def, this->def_length); } inline int MySQLField::get_charsetnr() const { if (this->data_type == MYSQL_TYPE_NULL) return 0; return this->charsetnr; } inline int MySQLField::get_length() const { if (this->data_type == MYSQL_TYPE_NULL) return 0; return this->length; } inline 
int MySQLField::get_flags() const { if (this->data_type == MYSQL_TYPE_NULL) return 0; return this->flags; } inline int MySQLField::get_decimals() const { if (this->data_type == MYSQL_TYPE_NULL) return 0; return this->decimals; } inline int MySQLField::get_data_type() const { return this->data_type; } inline MySQLCell::MySQLCell(MySQLCell&& move) { this->operator=(std::move(move)); } inline MySQLCell& MySQLCell::operator=(MySQLCell&& move) { if (this != &move) { this->data = move.data; this->len = move.len; this->data_type = move.data_type; move.data = NULL; move.len = 0; } return *this; } inline MySQLCell::MySQLCell(const void *data, size_t len, int data_type) { this->data_type = data_type; this->data = const_cast(data); this->len = len; } inline MySQLCell::MySQLCell() { this->data = NULL; this->len = 0; this->data_type = MYSQL_TYPE_NULL; } inline int MySQLCell::get_data_type() const { return this->data_type; } inline void MySQLCell::get_cell_nocopy(const void **data, size_t *len, int *data_type) const { *data = this->data; *len = this->len; *data_type = this->data_type; } inline bool MySQLCell::is_null() const { return (this->data_type == MYSQL_TYPE_NULL); } inline std::string MySQLCell::as_binary_string() const { return std::string((char *)this->data, this->len); } inline bool MySQLCell::is_int() const { return (this->data_type == MYSQL_TYPE_TINY || this->data_type == MYSQL_TYPE_SHORT || this->data_type == MYSQL_TYPE_INT24 || this->data_type == MYSQL_TYPE_LONG); } inline int MySQLCell::as_int() const { if (!this->is_int()) return 0; std::string num((char *)this->data, this->len); return atoi(num.c_str()); } inline bool MySQLCell::is_float() const { return (this->data_type == MYSQL_TYPE_FLOAT); } inline float MySQLCell::as_float() const { if (!this->is_float()) return NAN; std::string num((char *)this->data, this->len); return strtof(num.c_str(), NULL); } inline bool MySQLCell::is_double() const { return (this->data_type == MYSQL_TYPE_DOUBLE); } inline double 
MySQLCell::as_double() const { if (!this->is_double()) return NAN; std::string num((char *)this->data, this->len); return strtod(num.c_str(), NULL); } inline bool MySQLCell::is_ulonglong() const { return (this->data_type == MYSQL_TYPE_LONGLONG); } inline unsigned long long MySQLCell::as_ulonglong() const { if (!this->is_ulonglong()) return (unsigned long long)-1; std::string num((char *)this->data, this->len); return strtoull(num.c_str(), NULL, 10); } inline bool MySQLCell::is_date() const { return (this->data_type == MYSQL_TYPE_DATE); } inline std::string MySQLCell::as_date() const { if (!this->is_date()) return ""; return std::string((char *)this->data, this->len); } inline bool MySQLCell::is_time() const { return (this->data_type == MYSQL_TYPE_TIME); } inline std::string MySQLCell::as_time() const { if (!this->is_time()) return ""; return std::string((char *)this->data, this->len); } inline bool MySQLCell::is_datetime() const { return (this->data_type == MYSQL_TYPE_DATETIME || this->data_type == MYSQL_TYPE_TIMESTAMP); } inline std::string MySQLCell::as_datetime() const { if (!this->is_datetime()) return ""; return std::string((char *)this->data, this->len); } inline bool MySQLCell::is_string() const { return (this->data_type == MYSQL_TYPE_DECIMAL || this->data_type == MYSQL_TYPE_NEWDECIMAL || this->data_type == MYSQL_TYPE_STRING || this->data_type == MYSQL_TYPE_VARCHAR || this->data_type == MYSQL_TYPE_VAR_STRING || this->data_type == MYSQL_TYPE_JSON); } inline std::string MySQLCell::as_string() const { if (!this->is_string() && !this->is_time() && !this->is_date() && !this->is_datetime()) return ""; return std::string((char *)this->data, this->len); } template bool MySQLResultCursor::fetch_row(T& row_map) { if (this->status != MYSQL_STATUS_GET_RESULT) return false; unsigned long long len; const unsigned char *data; int data_type; const unsigned char *p = (const unsigned char *)this->pos; const unsigned char *end = (const unsigned char *)this->end; 
row_map.clear(); for (int i = 0; i < this->field_count; i++) { data_type = this->fields[i]->get_data_type(); if (*p == MYSQL_PACKET_HEADER_NULL) { data = NULL; len = 0; p++; data_type = MYSQL_TYPE_NULL; } else if (decode_string(&data, &len, &p, end) == 0) { this->status = MYSQL_STATUS_ERROR; return false; } row_map.emplace(this->fields[i]->get_name(), MySQLCell(data, len, data_type)); } this->pos = p; if (++this->current_row == this->row_count) this->status = MYSQL_STATUS_END; return true; } inline const MySQLField *MySQLResultCursor::fetch_field() { if (this->status != MYSQL_STATUS_GET_RESULT && this->status != MYSQL_STATUS_END) { return NULL; } if (this->current_field >= this->field_count) return NULL; return this->fields[this->current_field++]; } inline const MySQLField *const *MySQLResultCursor::fetch_fields() const { if (this->status != MYSQL_STATUS_GET_RESULT && this->status != MYSQL_STATUS_END) { return NULL; } return this->fields; } inline int MySQLResultCursor::get_cursor_status() const { return this->status; } inline int MySQLResultCursor::get_server_status() const { if (this->status != MYSQL_STATUS_GET_RESULT && this->status != MYSQL_STATUS_END && this->status != MYSQL_STATUS_OK) { return 0; } return this->server_status; } inline int MySQLResultCursor::get_field_count() const { if (this->status != MYSQL_STATUS_GET_RESULT && this->status != MYSQL_STATUS_END) { return 0; } return this->field_count; } inline int MySQLResultCursor::get_rows_count() const { if (this->status != MYSQL_STATUS_GET_RESULT && this->status != MYSQL_STATUS_END) { return 0; } return this->row_count; } inline unsigned long long MySQLResultCursor::get_affected_rows() const { if (this->status != MYSQL_PACKET_OK) return 0; return this->affected_rows; } inline int MySQLResultCursor::get_warnings() const { if (this->status != MYSQL_PACKET_OK) return 0; return this->warning_count; } inline unsigned long long MySQLResultCursor::get_insert_id() const { if (this->status != MYSQL_PACKET_OK) 
return 0; return this->insert_id; } inline std::string MySQLResultCursor::get_info() const { if (this->status != MYSQL_PACKET_OK) return ""; return std::string((char *)this->start, this->info_len); } inline void MySQLResultCursor::clear() { for (int i = 0; i < this->field_count; i++) delete this->fields[i]; delete []this->fields; } inline MySQLResultCursor::~MySQLResultCursor() { this->clear(); } inline MySQLResultCursor::MySQLResultCursor(MySQLResultCursor&& move) { this->start = move.start; this->end = move.end; this->pos = move.pos; this->status = move.status; this->row_data = move.row_data; this->fields = move.fields; this->row_count = move.row_count; this->field_count = move.field_count; this->current_row = move.current_row; this->current_field = move.current_field; this->affected_rows = move.affected_rows; this->insert_id = move.insert_id; this->warning_count = move.warning_count; this->info_len = move.info_len; this->cursor = move.cursor; this->parser = move.parser; move.init(); } inline MySQLResultCursor& MySQLResultCursor::operator=(MySQLResultCursor&& move) { if (this != &move) { this->clear(); this->start = move.start; this->end = move.end; this->pos = move.pos; this->status = move.status; this->row_data = move.row_data; this->fields = move.fields; this->row_count = move.row_count; this->field_count = move.field_count; this->current_row = move.current_row; this->current_field = move.current_field; this->affected_rows = move.affected_rows; this->insert_id = move.insert_id; this->warning_count = move.warning_count; this->info_len = move.info_len; this->cursor = move.cursor; this->parser = move.parser; move.init(); } return *this; } } workflow-0.11.8/src/protocol/MySQLUtil.cc000066400000000000000000000027441476003635400202340ustar00rootroot00000000000000/* Copyright (c) 2022 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Xie Han (xiehan@sogou-inc.com) */ #include #include "MySQLUtil.h" namespace protocol { std::string MySQLUtil::escape_string(const std::string& str) { std::string res; char escape; size_t i; for (i = 0; i < str.size(); i++) { switch (str[i]) { case '\0': escape = '0'; break; case '\n': escape = 'n'; break; case '\r': escape = 'r'; break; case '\\': escape = '\\'; break; case '\'': escape = '\''; break; case '\"': escape = '\"'; break; case '\032': escape = 'Z'; break; default: res.push_back(str[i]); continue; } res.push_back('\\'); res.push_back(escape); } return res; } std::string MySQLUtil::escape_string_quote(const std::string& str, char quote) { std::string res; size_t i; for (i = 0; i < str.size(); i++) { if (str[i] == quote) res.push_back(quote); res.push_back(str[i]); } return res; } } workflow-0.11.8/src/protocol/MySQLUtil.h000066400000000000000000000015701476003635400200720ustar00rootroot00000000000000/* Copyright (c) 2022 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Author: Xie Han (xiehan@sogou-inc.com) */ #ifndef _MYSQLUTIL_H_ #define _MYSQLUTIL_H_ #include namespace protocol { class MySQLUtil { public: static std::string escape_string(const std::string& str); static std::string escape_string_quote(const std::string& str, char quote); }; } #endif workflow-0.11.8/src/protocol/PackageWrapper.cc000066400000000000000000000025601476003635400213210ustar00rootroot00000000000000/* Copyright (c) 2022 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Xie Han (xiehan@sogou-inc.com) */ #include #include "PackageWrapper.h" namespace protocol { int PackageWrapper::encode(struct iovec vectors[], int max) { int cnt = 0; int ret; while (max >= 8) { ret = this->ProtocolWrapper::encode(vectors, max); if ((unsigned int)ret > (unsigned int)max) { if (ret < 0) return ret; break; } cnt += ret; this->set_message(this->next_out(this->message)); if (!this->message) return cnt; vectors += ret; max -= ret; } errno = EOVERFLOW; return -1; } int PackageWrapper::append(const void *buf, size_t *size) { int ret = this->ProtocolWrapper::append(buf, size); if (ret > 0) { this->set_message(this->next_in(this->message)); if (this->message) { this->renew(); ret = 0; } } return ret; } } workflow-0.11.8/src/protocol/PackageWrapper.h000066400000000000000000000024221476003635400211600ustar00rootroot00000000000000/* Copyright (c) 2022 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Xie Han (xiehan@sogou-inc.com) */ #ifndef _PACKAGEWRAPPER_H_ #define _PACKAGEWRAPPER_H_ #include "ProtocolMessage.h" namespace protocol { class PackageWrapper : public ProtocolWrapper { private: virtual ProtocolMessage *next_out(ProtocolMessage *message) { return NULL; } virtual ProtocolMessage *next_in(ProtocolMessage *message) { return NULL; } protected: virtual int encode(struct iovec vectors[], int max); virtual int append(const void *buf, size_t *size); public: PackageWrapper(ProtocolMessage *message) : ProtocolWrapper(message) { } public: PackageWrapper(PackageWrapper&& wrapper) = default; PackageWrapper& operator = (PackageWrapper&& wrapper) = default; }; } #endif workflow-0.11.8/src/protocol/ProtocolMessage.h000066400000000000000000000076471476003635400214100ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Author: Xie Han (xiehan@sogou-inc.com) */ #ifndef _PROTOCOLMESSAGE_H_ #define _PROTOCOLMESSAGE_H_ #include #include #include #include "Communicator.h" /** * @file ProtocolMessage.h * @brief General Protocol Interface */ namespace protocol { class ProtocolMessage : public CommMessageOut, public CommMessageIn { protected: virtual int encode(struct iovec vectors[], int max) { errno = ENOSYS; return -1; } /* You have to implement one of the 'append' functions, and the first one * with arguement 'size_t *size' is recommmended. */ /* Argument 'size' indicates bytes to append, and returns bytes used. */ virtual int append(const void *buf, size_t *size) { return this->append(buf, *size); } /* When implementing this one, all bytes are consumed. Cannot support * streaming protocol. */ virtual int append(const void *buf, size_t size) { errno = ENOSYS; return -1; } public: void set_size_limit(size_t limit) { this->size_limit = limit; } size_t get_size_limit() const { return this->size_limit; } public: class Attachment { public: virtual ~Attachment() { } }; void set_attachment(Attachment *att) { this->attachment = att; } Attachment *get_attachment() const { return this->attachment; } protected: virtual int feedback(const void *buf, size_t size) { if (this->wrapper) return this->wrapper->feedback(buf, size); else return this->CommMessageIn::feedback(buf, size); } virtual void renew() { if (this->wrapper) return this->wrapper->renew(); else return this->CommMessageIn::renew(); } virtual ProtocolMessage *inner() { return this; } protected: size_t size_limit; private: Attachment *attachment; ProtocolMessage *wrapper; public: ProtocolMessage() { this->size_limit = (size_t)-1; this->attachment = NULL; this->wrapper = NULL; } virtual ~ProtocolMessage() { delete this->attachment; } public: ProtocolMessage(ProtocolMessage&& message) { this->size_limit = message.size_limit; this->attachment = message.attachment; message.attachment = NULL; this->wrapper = NULL; } ProtocolMessage& operator 
= (ProtocolMessage&& message) { if (&message != this) { this->size_limit = message.size_limit; delete this->attachment; this->attachment = message.attachment; message.attachment = NULL; } return *this; } friend class ProtocolWrapper; }; class ProtocolWrapper : public ProtocolMessage { protected: virtual int encode(struct iovec vectors[], int max) { return this->message->encode(vectors, max); } virtual int append(const void *buf, size_t *size) { return this->message->append(buf, size); } protected: virtual ProtocolMessage *inner() { return this->message->inner(); } protected: void set_message(ProtocolMessage *message) { this->message = message; if (message) message->wrapper = this; } protected: ProtocolMessage *message; public: ProtocolWrapper(ProtocolMessage *message) { this->set_message(message); } public: ProtocolWrapper(ProtocolWrapper&& wrapper) : ProtocolMessage(std::move(wrapper)) { this->set_message(wrapper.message); wrapper.message = NULL; } ProtocolWrapper& operator = (ProtocolWrapper&& wrapper) { if (&wrapper != this) { *(ProtocolMessage *)this = std::move(wrapper); this->set_message(wrapper.message); wrapper.message = NULL; } return *this; } }; } #endif workflow-0.11.8/src/protocol/RedisMessage.cc000066400000000000000000000314701476003635400210020ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) Liu Kai (liukaidx@sogou-inc.com) */ #include #include #include #include #include "EncodeStream.h" #include "RedisMessage.h" namespace protocol { typedef int64_t Rint; typedef std::string Rstr; typedef std::vector Rarr; RedisValue& RedisValue::operator= (const RedisValue& copy) { if (this != ©) { free_data(); switch (copy.type_) { case REDIS_REPLY_TYPE_INTEGER: type_ = copy.type_; data_ = new Rint(*((Rint*)(copy.data_))); break; case REDIS_REPLY_TYPE_ERROR: case REDIS_REPLY_TYPE_STATUS: case REDIS_REPLY_TYPE_STRING: type_ = copy.type_; data_ = new Rstr(*((Rstr*)(copy.data_))); break; case REDIS_REPLY_TYPE_ARRAY: type_ = copy.type_; data_ = new Rarr(*((Rarr*)(copy.data_))); break; default: type_ = REDIS_REPLY_TYPE_NIL; data_ = NULL; } } return *this; } RedisValue& RedisValue::operator= (RedisValue&& move) { if (this != &move) { free_data(); type_ = move.type_; data_ = move.data_; move.type_ = REDIS_REPLY_TYPE_NIL; move.data_ = NULL; } return *this; } void RedisValue::free_data() { if (data_) { switch (type_) { case REDIS_REPLY_TYPE_INTEGER: delete (Rint *)data_; break; case REDIS_REPLY_TYPE_ERROR: case REDIS_REPLY_TYPE_STATUS: case REDIS_REPLY_TYPE_STRING: delete (Rstr *)data_; break; case REDIS_REPLY_TYPE_ARRAY: delete (Rarr *)data_; break; } data_ = NULL; } } void RedisValue::only_set_string_data(const std::string& strv) { Rstr *p = (Rstr *)(data_); p->assign(strv); } void RedisValue::only_set_string_data(const char *str, size_t len) { Rstr *p = (Rstr *)(data_); if (str == NULL || len == 0) p->clear(); else p->assign(str, len); } void RedisValue::only_set_string_data(const char *str) { Rstr *p = (Rstr *)(data_); if (str == NULL) p->clear(); else p->assign(str); } void RedisValue::set_int(int64_t intv) { if (type_ == REDIS_REPLY_TYPE_INTEGER) *((Rint *)data_) = intv; else { free_data(); data_ = new Rint(intv); type_ = REDIS_REPLY_TYPE_INTEGER; } } void RedisValue::set_string(const std::string& strv) { if (type_ == 
REDIS_REPLY_TYPE_STRING || type_ == REDIS_REPLY_TYPE_STATUS || type_ == REDIS_REPLY_TYPE_ERROR) only_set_string_data(strv); else { free_data(); data_ = new Rstr(strv); } type_ = REDIS_REPLY_TYPE_STRING; } void RedisValue::set_status(const std::string& strv) { if (type_ == REDIS_REPLY_TYPE_STRING || type_ == REDIS_REPLY_TYPE_STATUS || type_ == REDIS_REPLY_TYPE_ERROR) only_set_string_data(strv); else { free_data(); data_ = new Rstr(strv); } type_ = REDIS_REPLY_TYPE_STATUS; } void RedisValue::set_error(const std::string& strv) { if (type_ == REDIS_REPLY_TYPE_STRING || type_ == REDIS_REPLY_TYPE_STATUS || type_ == REDIS_REPLY_TYPE_ERROR) only_set_string_data(strv); else { free_data(); data_ = new Rstr(strv); } type_ = REDIS_REPLY_TYPE_ERROR; } void RedisValue::set_string(const char *str, size_t len) { if (type_ == REDIS_REPLY_TYPE_STRING || type_ == REDIS_REPLY_TYPE_STATUS || type_ == REDIS_REPLY_TYPE_ERROR) only_set_string_data(str, len); else { free_data(); data_ = new Rstr(str, len); } type_ = REDIS_REPLY_TYPE_STRING; } void RedisValue::set_status(const char *str, size_t len) { if (type_ == REDIS_REPLY_TYPE_STRING || type_ == REDIS_REPLY_TYPE_STATUS || type_ == REDIS_REPLY_TYPE_ERROR) only_set_string_data(str, len); else { free_data(); data_ = new Rstr(str, len); } type_ = REDIS_REPLY_TYPE_STATUS; } void RedisValue::set_error(const char *str, size_t len) { if (type_ == REDIS_REPLY_TYPE_STRING || type_ == REDIS_REPLY_TYPE_STATUS || type_ == REDIS_REPLY_TYPE_ERROR) only_set_string_data(str, len); else { free_data(); data_ = new Rstr(str, len); } type_ = REDIS_REPLY_TYPE_ERROR; } void RedisValue::set_string(const char *str) { if (type_ == REDIS_REPLY_TYPE_STRING || type_ == REDIS_REPLY_TYPE_STATUS || type_ == REDIS_REPLY_TYPE_ERROR) only_set_string_data(str); else { free_data(); data_ = new Rstr(str); } type_ = REDIS_REPLY_TYPE_STRING; } void RedisValue::set_status(const char *str) { if (type_ == REDIS_REPLY_TYPE_STRING || type_ == REDIS_REPLY_TYPE_STATUS || type_ == 
REDIS_REPLY_TYPE_ERROR) only_set_string_data(str); else { free_data(); data_ = new Rstr(str); } type_ = REDIS_REPLY_TYPE_STATUS; } void RedisValue::set_error(const char *str) { if (type_ == REDIS_REPLY_TYPE_STRING || type_ == REDIS_REPLY_TYPE_STATUS || type_ == REDIS_REPLY_TYPE_ERROR) only_set_string_data(str); else { free_data(); data_ = new Rstr(str); } type_ = REDIS_REPLY_TYPE_ERROR; } void RedisValue::set_array(size_t new_size) { if (type_ == REDIS_REPLY_TYPE_ARRAY) ((Rarr *)data_)->resize(new_size); else { free_data(); data_ = new Rarr(new_size); type_ = REDIS_REPLY_TYPE_ARRAY; } } void RedisValue::set(const redis_reply_t *reply) { set_nil(); switch (reply->type) { case REDIS_REPLY_TYPE_INTEGER: set_int(reply->integer); break; case REDIS_REPLY_TYPE_ERROR: set_error(reply->str, reply->len); break; case REDIS_REPLY_TYPE_STATUS: set_status(reply->str, reply->len); break; case REDIS_REPLY_TYPE_STRING: set_string(reply->str, reply->len); break; case REDIS_REPLY_TYPE_ARRAY: set_array(reply->elements); if (reply->elements > 0) { Rarr *parr = (Rarr *)data_; for (size_t i = 0; i < reply->elements; i++) (*parr)[i].set(reply->element[i]); } break; } } void RedisValue::arr_clear() { if (type_ == REDIS_REPLY_TYPE_ARRAY) ((Rarr *)data_)->clear(); } void RedisValue::arr_resize(size_t new_size) { if (type_ == REDIS_REPLY_TYPE_ARRAY) ((Rarr *)data_)->resize(new_size); } bool RedisValue::transform(redis_reply_t *reply) const { //todo risk of stack overflow Rarr *parr; Rstr *pstr; redis_reply_set_null(reply); switch (type_) { case REDIS_REPLY_TYPE_INTEGER: redis_reply_set_integer(*((Rint *)data_), reply); break; case REDIS_REPLY_TYPE_ARRAY: parr = (Rarr *)data_; if (redis_reply_set_array(parr->size(), reply) < 0) return false; for (size_t i = 0; i < reply->elements; i++) { if (!(*parr)[i].transform(reply->element[i])) return false; } break; case REDIS_REPLY_TYPE_STATUS: pstr = (Rstr *)data_; redis_reply_set_status(pstr->c_str(), pstr->size(), reply); break; case 
REDIS_REPLY_TYPE_ERROR: pstr = (Rstr *)data_; redis_reply_set_error(pstr->c_str(), pstr->size(), reply); break; case REDIS_REPLY_TYPE_STRING: pstr = (Rstr *)data_; redis_reply_set_string(pstr->c_str(), pstr->size(), reply); break; } return true; } std::string RedisValue::debug_string() const { std::string ret; if (is_error()) { ret += "ERROR: "; ret += string_view()->c_str(); } else if (is_int()) { std::ostringstream oss; oss << int_value(); ret += oss.str(); } else if (is_nil()) { ret += "nil"; } else if (is_string()) { ret += '\"'; ret += string_view()->c_str(); ret += '\"'; } else if (is_array()) { ret += '['; size_t l = arr_size(); for (size_t i = 0; i < l; i++) { if (i) ret += ", "; ret += (*this)[i].debug_string(); } ret += ']'; } return ret; } RedisValue::~RedisValue() { free_data(); } RedisMessage::RedisMessage(): parser_(new redis_parser_t), stream_(new EncodeStream), cur_size_(0), asking_(false) { redis_parser_init(parser_); } RedisMessage::~RedisMessage() { if (parser_) { redis_parser_deinit(parser_); delete parser_; delete stream_; } } RedisMessage::RedisMessage(RedisMessage&& move) : ProtocolMessage(std::move(move)) { parser_ = move.parser_; stream_ = move.stream_; cur_size_ = move.cur_size_; asking_ = move.asking_; move.parser_ = NULL; move.stream_ = NULL; move.cur_size_ = 0; move.asking_ = false; } RedisMessage& RedisMessage::operator= (RedisMessage &&move) { if (this != &move) { *(ProtocolMessage *)this = std::move(move); if (parser_) { redis_parser_deinit(parser_); delete parser_; delete stream_; } parser_ = move.parser_; stream_ = move.stream_; cur_size_ = move.cur_size_; asking_ = move.asking_; move.parser_ = NULL; move.stream_ = NULL; move.cur_size_ = 0; move.asking_ = false; } return *this; } bool RedisMessage::encode_reply(redis_reply_t *reply) { EncodeStream& stream = *stream_; switch (reply->type) { case REDIS_REPLY_TYPE_STATUS: stream << "+" << std::make_pair(reply->str, reply->len) << "\r\n"; break; case REDIS_REPLY_TYPE_ERROR: stream << 
"-" << std::make_pair(reply->str, reply->len) << "\r\n"; break; case REDIS_REPLY_TYPE_NIL: stream << "$-1\r\n"; break; case REDIS_REPLY_TYPE_INTEGER: stream << ":" << reply->integer << "\r\n"; break; case REDIS_REPLY_TYPE_STRING: stream << "$" << reply->len << "\r\n"; stream << std::make_pair(reply->str, reply->len) << "\r\n"; break; case REDIS_REPLY_TYPE_ARRAY: stream << "*" << reply->elements << "\r\n"; for (size_t i = 0; i < reply->elements; i++) if (!encode_reply(reply->element[i])) return false; break; default: return false; } return true; } int RedisMessage::encode(struct iovec vectors[], int max) { stream_->reset(vectors, max); if (encode_reply(&parser_->reply)) return stream_->size(); return 0; } int RedisMessage::append(const void *buf, size_t *size) { int ret = redis_parser_append_message(buf, size, parser_); if (ret >= 0) { cur_size_ += *size; if (cur_size_ > this->size_limit) { errno = EMSGSIZE; ret = -1; } } else if (ret == -2) { errno = EBADMSG; ret = -1; } return ret; } void RedisRequest::set_request(const std::string& command, const std::vector& params) { size_t n = params.size() + 1; user_request_.reserve(n); user_request_.clear(); user_request_.push_back(command); for (size_t i = 0; i < params.size(); i++) user_request_.push_back(params[i]); redis_reply_t *reply = &parser_->reply; redis_reply_set_array(n, reply); for (size_t i = 0; i < n; i++) { redis_reply_set_string(user_request_[i].c_str(), user_request_[i].size(), reply->element[i]); } } bool RedisRequest::get_command(std::string& command) const { const redis_reply_t *reply = &parser_->reply; if (reply->type == REDIS_REPLY_TYPE_ARRAY && reply->elements > 0) { reply = reply->element[0]; if (reply->type == REDIS_REPLY_TYPE_STRING) { command.assign(reply->str, reply->len); return true; } } return false; } bool RedisRequest::get_params(std::vector& params) const { const redis_reply_t *reply = &parser_->reply; if (reply->type == REDIS_REPLY_TYPE_ARRAY && reply->elements > 0) { for (size_t i = 1; i 
< reply->elements; i++) { if (reply->element[i]->type != REDIS_REPLY_TYPE_STRING && reply->element[i]->type != REDIS_REPLY_TYPE_NIL) { return false; } } params.reserve(reply->elements - 1); params.clear(); for (size_t i = 1; i < reply->elements; i++) params.emplace_back(reply->element[i]->str, reply->element[i]->len); return true; } return false; } #define REDIS_ASK_COMMAND "ASKING" #define REDIS_ASK_REQUEST "*1\r\n$6\r\nASKING\r\n" #define REDIS_OK_RESPONSE "+OK\r\n" int RedisRequest::encode(struct iovec vectors[], int max) { stream_->reset(vectors, max); if (is_asking()) (*stream_) << REDIS_ASK_REQUEST; if (encode_reply(&parser_->reply)) return stream_->size(); return 0; } int RedisRequest::append(const void *buf, size_t *size) { int ret = RedisMessage::append(buf, size); if (ret > 0) { std::string command; if (get_command(command) && strcasecmp(command.c_str(), REDIS_ASK_COMMAND) == 0) { redis_parser_deinit(parser_); redis_parser_init(parser_); set_asking(true); ret = this->feedback(REDIS_OK_RESPONSE, strlen(REDIS_OK_RESPONSE)); if (ret != strlen(REDIS_OK_RESPONSE)) { errno = ENOBUFS; ret = -1; } else ret = 0; } } return ret; } int RedisResponse::append(const void *buf, size_t *size) { int ret = RedisMessage::append(buf, size); if (ret > 0 && is_asking()) { redis_parser_deinit(parser_); redis_parser_init(parser_); ret = 0; set_asking(false); } return ret; } bool RedisResponse::set_result(const RedisValue& value) { redis_reply_t *reply = &parser_->reply; redis_reply_deinit(reply); redis_reply_init(reply); value_ = value; return value_.transform(reply); } } workflow-0.11.8/src/protocol/RedisMessage.h000066400000000000000000000200301476003635400206320ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
/**
 * A dynamically typed Redis value: nil, integer, string, status, error or
 * array. The type tag is stored in type_ and the payload, if any, behind
 * the untyped pointer data_.
 */
class RedisValue
{
public:
	// Constructs a nil value
	RedisValue();
	virtual ~RedisValue();

	// Copy constructor
	RedisValue(const RedisValue& copy);
	// Copy assignment
	RedisValue& operator= (const RedisValue& copy);
	// Move constructor
	RedisValue(RedisValue&& move);
	// Move assignment; the source becomes nil
	RedisValue& operator= (RedisValue&& move);

	// Release memory and change type to nil
	void set_nil();
	void set_int(int64_t intv);
	void set_string(const std::string& strv);
	void set_status(const std::string& strv);
	void set_error(const std::string& strv);
	void set_string(const char *str, size_t len);
	void set_status(const char *str, size_t len);
	void set_error(const char *str, size_t len);
	void set_string(const char *str);
	void set_status(const char *str);
	void set_error(const char *str);
	// Become an array of new_size elements (resizes if already an array)
	void set_array(size_t new_size);
	// Set data from the C style data struct
	void set(const redis_reply_t *reply);

	// Return true if not error
	bool is_ok() const;
	// Return true if error
	bool is_error() const;
	// Return true if nil
	bool is_nil() const;
	// Return true if integer
	bool is_int() const;
	// Return true if array
	bool is_array() const;
	// Return true if string/status
	bool is_string() const;
	// Return the type code used by the C style data struct
	int get_type() const;

	// Copy. If type is not string/status/error, returns an empty std::string
	std::string string_value() const;
	// No copy. If type is not string/status/error, returns NULL.
	const std::string *string_view() const;
	// If type is not integer, returns 0
	int64_t int_value() const;
	// If type is not array, returns 0
	size_t arr_size() const;
	// If type is not array, do nothing
	void arr_clear();
	// If type is not array, do nothing
	void arr_resize(size_t new_size);
	// Returns std::vector.at(pos) without checking the type first: only
	// call on an array. pos is bounds-checked by at().
	RedisValue& arr_at(size_t pos) const;
	// Returns std::vector[pos] without checking the type first: only call
	// on an array. pos is NOT bounds-checked.
	RedisValue& operator[] (size_t pos) const;

	// Transform data into the C style data struct; false on failure
	bool transform(redis_reply_t *reply) const;
	// Equal to set_nil();
	void clear();
	// Format data to text
	std::string debug_string() const;

private:
	// Delete the payload according to type_
	void free_data();
	// Overwrite the existing std::string payload; the caller must have
	// checked that the current type is string/status/error
	void only_set_string_data(const std::string& strv);
	void only_set_string_data(const char *str, size_t len);
	void only_set_string_data(const char *str);

	int type_;
	void *data_;
};
bool parse_success() const; bool is_asking() const; void set_asking(bool asking); protected: redis_parser_t *parser_; virtual int encode(struct iovec vectors[], int max); virtual int append(const void *buf, size_t *size); bool encode_reply(redis_reply_t *reply); class EncodeStream *stream_; private: size_t cur_size_; bool asking_; }; class RedisRequest : public RedisMessage { public: RedisRequest() = default; //move constructor RedisRequest(RedisRequest&& move) = default; //move operator RedisRequest& operator= (RedisRequest&& move) = default; public:// C++ style // Usually, client use set_request to (prepare)send request to server // Usually, server use get_command/get_params to get client request // set_request("HSET", {"keyname", "hashkey", "somevalue"}); void set_request(const std::string& command, const std::vector& params); bool get_command(std::string& command) const; bool get_params(std::vector& params) const; protected: virtual int encode(struct iovec vectors[], int max); virtual int append(const void *buf, size_t *size); private: std::vector user_request_; }; class RedisResponse : public RedisMessage { public: RedisResponse() = default; //move constructor RedisResponse(RedisResponse&& move) = default; //move operator RedisResponse& operator= (RedisResponse&& move) = default; public:// C++ style // client use get_result to get result from server, copy void get_result(RedisValue& value) const; // server use set_result to (prepare)send result to client, copy bool set_result(const RedisValue& value); public:// C style // redis_parser_t is absolutely same as hiredis-redisReply in memory // If you include hiredis.h, redisReply* can cast to redis_reply_t* safely // BUT this function return not a copy, DONOT free the pointer by yourself // client read data from redis_reply_t by pointer of result_ptr // server write data into redis_reply_t by pointer of result_ptr redis_reply_t *result_ptr(); protected: virtual int append(const void *buf, size_t *size); private: 
RedisValue value_; }; //////////////////// inline RedisValue::RedisValue(): type_(REDIS_REPLY_TYPE_NIL), data_(NULL) { } inline RedisValue::RedisValue(const RedisValue& copy): type_(REDIS_REPLY_TYPE_NIL), data_(NULL) { this->operator= (copy); } inline RedisValue::RedisValue(RedisValue&& move): type_(REDIS_REPLY_TYPE_NIL), data_(NULL) { this->operator= (std::move(move)); } inline bool RedisValue::is_ok() const { return type_ != REDIS_REPLY_TYPE_ERROR; } inline bool RedisValue::is_error() const { return type_ == REDIS_REPLY_TYPE_ERROR; } inline bool RedisValue::is_nil() const { return type_ == REDIS_REPLY_TYPE_NIL; } inline bool RedisValue::is_int() const { return type_ == REDIS_REPLY_TYPE_INTEGER; } inline bool RedisValue::is_array() const { return type_ == REDIS_REPLY_TYPE_ARRAY; } inline int RedisValue::get_type() const { return type_; } inline bool RedisValue::is_string() const { return type_ == REDIS_REPLY_TYPE_STRING || type_ == REDIS_REPLY_TYPE_STATUS; } inline std::string RedisValue::string_value() const { if (type_ == REDIS_REPLY_TYPE_STRING || type_ == REDIS_REPLY_TYPE_STATUS || type_ == REDIS_REPLY_TYPE_ERROR) return *((std::string *)data_); else return ""; } inline const std::string *RedisValue::string_view() const { if (type_ == REDIS_REPLY_TYPE_STRING || type_ == REDIS_REPLY_TYPE_STATUS || type_ == REDIS_REPLY_TYPE_ERROR) return ((std::string *)data_); else return NULL; } inline int64_t RedisValue::int_value() const { if (type_ == REDIS_REPLY_TYPE_INTEGER) return *((int64_t *)data_); else return 0; } inline size_t RedisValue::arr_size() const { if (type_ == REDIS_REPLY_TYPE_ARRAY) return ((std::vector *)data_)->size(); else return 0; } inline RedisValue& RedisValue::arr_at(size_t pos) const { return ((std::vector *)data_)->at(pos); } inline RedisValue& RedisValue::operator[] (size_t pos) const { return (*((std::vector *)data_))[pos]; } inline void RedisValue::set_nil() { free_data(); type_ = REDIS_REPLY_TYPE_NIL; } inline void RedisValue::clear() { 
set_nil(); } inline bool RedisMessage::parse_success() const { return parser_->parse_succ; } inline bool RedisMessage::is_asking() const { return asking_; } inline void RedisMessage::set_asking(bool asking) { asking_ = asking; } inline redis_reply_t *RedisResponse::result_ptr() { return &parser_->reply; } inline void RedisResponse::get_result(RedisValue& value) const { if (parser_->parse_succ) value.set(&parser_->reply); else value.set_nil(); } } #endif workflow-0.11.8/src/protocol/SSLWrapper.cc000066400000000000000000000116151476003635400204300ustar00rootroot00000000000000/* Copyright (c) 2021 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Xie Han (xiehan@sogou-inc.com) */ #include #include #include #include #include "SSLWrapper.h" namespace protocol { #if OPENSSL_VERSION_NUMBER < 0x10100000L static inline BIO *__get_wbio(SSL *ssl) { BIO *wbio = SSL_get_wbio(ssl); BIO *next = BIO_next(wbio); return next ? 
next : wbio; } # define SSL_get_wbio(ssl) __get_wbio(ssl) #endif int SSLHandshaker::encode(struct iovec vectors[], int max) { BIO *wbio = SSL_get_wbio(this->ssl); char *ptr; long len; int ret; ret = SSL_do_handshake(this->ssl); if (ret <= 0) { ret = SSL_get_error(this->ssl, ret); if (ret != SSL_ERROR_WANT_READ) { if (ret != SSL_ERROR_SYSCALL) errno = -ret; return -1; } } len = BIO_get_mem_data(wbio, &ptr); if (len > 0) { vectors[0].iov_base = ptr; vectors[0].iov_len = len; return 1; } else if (len == 0) return 0; else return -1; } static int __ssl_handshake(const void *buf, size_t *size, SSL *ssl, char **ptr, long *len) { BIO *wbio = SSL_get_wbio(ssl); BIO *rbio = SSL_get_rbio(ssl); int ret; ret = BIO_write(rbio, buf, *size); if (ret <= 0) return -1; *size = ret; ret = SSL_do_handshake(ssl); if (ret <= 0) { ret = SSL_get_error(ssl, ret); if (ret != SSL_ERROR_WANT_READ) { if (ret != SSL_ERROR_SYSCALL) errno = -ret; return -1; } ret = 0; } *len = BIO_get_mem_data(wbio, ptr); if (*len < 0) return -1; return ret; } int SSLHandshaker::append(const void *buf, size_t *size) { BIO *wbio = SSL_get_wbio(this->ssl); char *ptr; long len; long n; int ret; BIO_reset(wbio); ret = __ssl_handshake(buf, size, this->ssl, &ptr, &len); if (ret != 0) return ret; if (len > 0) { n = this->feedback(ptr, len); BIO_reset(wbio); } else n = 0; if (n == len) return ret; if (n >= 0) errno = ENOBUFS; return -1; } int SSLWrapper::encode(struct iovec vectors[], int max) { BIO *wbio = SSL_get_wbio(this->ssl); struct iovec *iov; char *ptr; long len; int ret; ret = this->ProtocolWrapper::encode(vectors, max); if ((unsigned int)ret > (unsigned int)max) return ret; max = ret; for (iov = vectors; iov < vectors + max; iov++) { if (iov->iov_len > 0) { ret = SSL_write(this->ssl, iov->iov_base, iov->iov_len); if (ret <= 0) { ret = SSL_get_error(this->ssl, ret); if (ret != SSL_ERROR_SYSCALL) errno = -ret; return -1; } } } len = BIO_get_mem_data(wbio, &ptr); if (len > 0) { vectors[0].iov_base = ptr; 
vectors[0].iov_len = len; return 1; } else if (len == 0) return 0; else return -1; } #define BUFSIZE 8192 int SSLWrapper::append_message() { char buf[BUFSIZE]; int ret; while ((ret = SSL_read(this->ssl, buf, BUFSIZE)) > 0) { size_t nleft = ret; char *p = buf; size_t n; do { n = nleft; ret = this->ProtocolWrapper::append(p, &n); if (ret == 0) { nleft -= n; p += n; } else return ret; } while (nleft > 0); } if (ret < 0) { ret = SSL_get_error(this->ssl, ret); if (ret != SSL_ERROR_WANT_READ) { if (ret != SSL_ERROR_SYSCALL) errno = -ret; return -1; } } return 0; } int SSLWrapper::append(const void *buf, size_t *size) { BIO *wbio = SSL_get_wbio(this->ssl); BIO *rbio = SSL_get_rbio(this->ssl); int ret; BIO_reset(wbio); ret = BIO_write(rbio, buf, *size); if (ret <= 0) return -1; *size = ret; return this->append_message(); } int SSLWrapper::feedback(const void *buf, size_t size) { BIO *wbio = SSL_get_wbio(this->ssl); char *ptr; long len; long n; int ret; if (size == 0) return 0; ret = SSL_write(this->ssl, buf, size); if (ret <= 0) { ret = SSL_get_error(this->ssl, ret); if (ret != SSL_ERROR_SYSCALL) errno = -ret; return -1; } len = BIO_get_mem_data(wbio, &ptr); if (len >= 0) { n = this->ProtocolWrapper::feedback(ptr, len); BIO_reset(wbio); if (n == len) return size; if (ret > 0) errno = ENOBUFS; } return -1; } int ServerSSLWrapper::append(const void *buf, size_t *size) { BIO *wbio = SSL_get_wbio(this->ssl); char *ptr; long len; long n; BIO_reset(wbio); if (__ssl_handshake(buf, size, this->ssl, &ptr, &len) < 0) return -1; if (len > 0) { n = this->ProtocolMessage::feedback(ptr, len); BIO_reset(wbio); } else n = 0; if (n == len) return this->append_message(); if (n >= 0) errno = ENOBUFS; return -1; } } workflow-0.11.8/src/protocol/SSLWrapper.h000066400000000000000000000037121476003635400202710ustar00rootroot00000000000000/* Copyright (c) 2021 Sogou, Inc. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Xie Han (xiehan@sogou-inc.com) */ #ifndef _SSLWRAPPER_H_ #define _SSLWRAPPER_H_ #include #include "ProtocolMessage.h" namespace protocol { class SSLHandshaker : public ProtocolMessage { public: virtual int encode(struct iovec vectors[], int max); virtual int append(const void *buf, size_t *size); protected: SSL *ssl; public: SSLHandshaker(SSL *ssl) { this->ssl = ssl; } public: SSLHandshaker(SSLHandshaker&& handshaker) = default; SSLHandshaker& operator = (SSLHandshaker&& handshaker) = default; }; class SSLWrapper : public ProtocolWrapper { protected: virtual int encode(struct iovec vectors[], int max); virtual int append(const void *buf, size_t *size); protected: virtual int feedback(const void *buf, size_t size); protected: int append_message(); protected: SSL *ssl; public: SSLWrapper(ProtocolMessage *message, SSL *ssl) : ProtocolWrapper(message) { this->ssl = ssl; } public: SSLWrapper(SSLWrapper&& wrapper) = default; SSLWrapper& operator = (SSLWrapper&& wrapper) = default; }; class ServerSSLWrapper : public SSLWrapper { protected: virtual int append(const void *buf, size_t *size); public: ServerSSLWrapper(ProtocolMessage *message, SSL *ssl) : SSLWrapper(message, ssl) { } public: ServerSSLWrapper(ServerSSLWrapper&& wrapper) = default; ServerSSLWrapper& operator = (ServerSSLWrapper&& wrapper) = default; }; } #endif workflow-0.11.8/src/protocol/TLVMessage.cc000066400000000000000000000035741476003635400204050ustar00rootroot00000000000000/* Copyright 
(c) 2023 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Xie Han (xiehan@sogou-inc.com) */ #include #include #include #include #include #include "TLVMessage.h" namespace protocol { int TLVMessage::encode(struct iovec vectors[], int max) { this->head[0] = htonl((uint32_t)this->type); this->head[1] = htonl(this->value.size()); vectors[0].iov_base = this->head; vectors[0].iov_len = 8; vectors[1].iov_base = (char *)this->value.data(); vectors[1].iov_len = this->value.size(); return 2; } int TLVMessage::append(const void *buf, size_t *size) { size_t n = *size; size_t head_left; head_left = 8 - this->head_received; if (head_left > 0) { void *p = (char *)this->head + this->head_received; if (n < head_left) { memcpy(p, buf, n); this->head_received += n; return 0; } memcpy(p, buf, head_left); this->head_received = 8; buf = (const char *)buf + head_left; n -= head_left; this->type = (int)ntohl(this->head[0]); *this->head = ntohl(this->head[1]); if (*this->head > this->size_limit) { errno = EMSGSIZE; return -1; } this->value.reserve(*this->head); } if (this->value.size() + n > *this->head) { n = *this->head - this->value.size(); *size = n + head_left; } this->value.append((const char *)buf, n); return this->value.size() == *this->head; } } workflow-0.11.8/src/protocol/TLVMessage.h000066400000000000000000000027571476003635400202510ustar00rootroot00000000000000/* Copyright (c) 2023 Sogou, Inc. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Xie Han (xiehan@sogou-inc.com) */ #ifndef _TLVMESSAGE_H_ #define _TLVMESSAGE_H_ #include #include #include #include "ProtocolMessage.h" namespace protocol { class TLVMessage : public ProtocolMessage { public: int get_type() const { return this->type; } void set_type(int type) { this->type = type; } std::string *get_value() { return &this->value; } void set_value(std::string value) { this->value = std::move(value); } protected: virtual int encode(struct iovec vectors[], int max); virtual int append(const void *buf, size_t *size); protected: int type; std::string value; private: uint32_t head[2]; size_t head_received; public: TLVMessage() { this->type = 0; this->head_received = 0; } public: TLVMessage(TLVMessage&& msg) = default; TLVMessage& operator = (TLVMessage&& msg) = default; }; using TLVRequest = TLVMessage; using TLVResponse = TLVMessage; } #endif workflow-0.11.8/src/protocol/dns_parser.c000066400000000000000000000667701476003635400204370ustar00rootroot00000000000000/* Copyright (c) 2021 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
/* Read a 32-bit unsigned integer stored in network (big-endian) byte order
 * from the four bytes starting at 'ptr'. */
static inline uint32_t __dns_parser_uint32(const char *ptr)
{
	const unsigned char *bytes = (const unsigned char *)ptr;
	uint32_t value = 0;
	int i;

	/* Most significant byte first. */
	for (i = 0; i < 4; i++)
		value = (value << 8) | bytes[i];

	return value;
}
* * phost must point to an char array with at least DNS_NAMES_MAX+1 size */ static int __dns_parser_parse_host(char *phost, dns_parser_t *parser) { uint8_t len; uint16_t pointer; size_t hcur; const char *msgend; const char **cur; const char *curbackup; // backup cur when host label is pointer msgend = (const char *)parser->msgbuf + parser->msgsize; cur = &(parser->cur); curbackup = NULL; hcur = 0; if (*cur >= msgend) return -2; while (*cur < msgend) { len = __dns_parser_uint8(*cur); if ((len & 0xC0) == 0) { (*cur)++; if (len == 0) break; if (len > DNS_LABELS_MAX || *cur + len > msgend || hcur + len + 1 > DNS_NAMES_MAX) return -2; memcpy(phost + hcur, *cur, len); *cur += len; hcur += len; phost[hcur++] = '.'; } // RFC 1035, 4.1.4 Message compression else if ((len & 0xC0) == 0xC0) { pointer = __dns_parser_uint16(*cur) & 0x3FFF; if (pointer >= parser->msgsize) return -2; // pointer must point to a prior position if ((const char *)parser->msgbase + pointer >= *cur) return -2; *cur += 2; // backup cur only when the first pointer occurs if (curbackup == NULL) curbackup = *cur; *cur = (const char *)parser->msgbase + pointer; } else return -2; } if (curbackup != NULL) *cur = curbackup; if (hcur > 1 && phost[hcur - 1] == '.') hcur--; if (hcur == 0) phost[hcur++] = '.'; phost[hcur++] = '\0'; return 0; } static void __dns_parser_free_record(struct __dns_record_entry *r) { switch (r->record.type) { case DNS_TYPE_SOA: { struct dns_record_soa *soa; soa = (struct dns_record_soa *)(r->record.rdata); free(soa->mname); free(soa->rname); break; } case DNS_TYPE_SRV: { struct dns_record_srv *srv; srv = (struct dns_record_srv *)(r->record.rdata); free(srv->target); break; } case DNS_TYPE_MX: { struct dns_record_mx *mx; mx = (struct dns_record_mx *)(r->record.rdata); free(mx->exchange); break; } } free(r->record.name); free(r); } static void __dns_parser_free_record_list(struct list_head *head) { struct list_head *pos, *tmp; struct __dns_record_entry *entry; list_for_each_safe(pos, tmp, 
head) { entry = list_entry(pos, struct __dns_record_entry, entry_list); list_del(pos); __dns_parser_free_record(entry); } } /* * A RDATA format, from RFC 1035: * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ * | ADDRESS | * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ * * ADDRESS: A 32 bit Internet address. * Hosts that have multiple Internet addresses will have multiple A records. */ static int __dns_parser_parse_a(struct __dns_record_entry **r, uint16_t rdlength, dns_parser_t *parser) { const char **cur; struct __dns_record_entry *entry; size_t entry_size; if (sizeof (struct in_addr) != rdlength) return -2; cur = &(parser->cur); entry_size = sizeof (struct __dns_record_entry) + sizeof (struct in_addr); entry = (struct __dns_record_entry *)malloc(entry_size); if (!entry) return -1; entry->record.rdata = (void *)(entry + 1); memcpy(entry->record.rdata, *cur, rdlength); *cur += rdlength; *r = entry; return 0; } /* * AAAA RDATA format, from RFC 3596: * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ * | ADDRESS | * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ * * ADDRESS: A 128 bit Internet address. * Hosts that have multiple addresses will have multiple AAAA records. */ static int __dns_parser_parse_aaaa(struct __dns_record_entry **r, uint16_t rdlength, dns_parser_t *parser) { const char **cur; struct __dns_record_entry *entry; size_t entry_size; if (sizeof (struct in6_addr) != rdlength) return -2; cur = &(parser->cur); entry_size = sizeof (struct __dns_record_entry) + sizeof (struct in6_addr); entry = (struct __dns_record_entry *)malloc(entry_size); if (!entry) return -1; entry->record.rdata = (void *)(entry + 1); memcpy(entry->record.rdata, *cur, rdlength); *cur += rdlength; *r = entry; return 0; } /* * Parse any record. 
*/
static int __dns_parser_parse_names(struct __dns_record_entry **r,
									uint16_t rdlength,
									dns_parser_t *parser)
{
	char host[DNS_NAMES_MAX + 2];
	const char *const rdata_end = parser->cur + rdlength;
	struct __dns_record_entry *rec;
	size_t copy_len;
	int rc;

	rc = __dns_parser_parse_host(host, parser);
	if (rc < 0)
		return rc;

	// the name has to consume the RDATA exactly
	if (parser->cur != rdata_end)
		return -2;

	// rdata is the NUL-terminated name, stored right behind the entry
	copy_len = strlen(host) + 1;
	rec = (struct __dns_record_entry *)malloc(sizeof *rec + copy_len);
	if (rec == NULL)
		return -1;

	rec->record.rdata = (void *)(rec + 1);
	memcpy(rec->record.rdata, host, copy_len);
	*r = rec;

	return 0;
}

/*
 * SOA RDATA format, from RFC 1035:
 *                                 1  1  1  1  1  1
 *   0  1  2  3  4  5  6  7  8  9  0  1  2  3  4  5
 * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
 * /                     MNAME                     /
 * /                                               /
 * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
 * /                     RNAME                     /
 * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
 * |                    SERIAL                     |
 * |                                               |
 * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
 * |                    REFRESH                    |
 * |                                               |
 * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
 * |                     RETRY                     |
 * |                                               |
 * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
 * |                    EXPIRE                     |
 * |                                               |
 * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
 * |                    MINIMUM                    |
 * |                                               |
 * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
 *
 * MNAME: Domain name of the primary name server for the zone.
 * RNAME: Domain name encoding the mailbox of the zone's administrator.
 * SERIAL: The unsigned 32 bit version number.
 * REFRESH: A 32 bit time interval.
 * RETRY: A 32 bit time interval.
 * EXPIRE: A 32 bit time value.
 * MINIMUM: The unsigned 32 bit integer.
*/
static int __dns_parser_parse_soa(struct __dns_record_entry **r,
								  uint16_t rdlength,
								  dns_parser_t *parser)
{
	// SERIAL + REFRESH + RETRY + EXPIRE + MINIMUM, four bytes each
	enum { SOA_FIXED_LEN = 20 };
	char mname[DNS_NAMES_MAX + 2];
	char rname[DNS_NAMES_MAX + 2];
	const char *const rdata_end = parser->cur + rdlength;
	const char *fixed;
	struct __dns_record_entry *rec;
	struct dns_record_soa *soa;
	int rc;

	rc = __dns_parser_parse_host(mname, parser);
	if (rc >= 0)
		rc = __dns_parser_parse_host(rname, parser);

	if (rc < 0)
		return rc;

	// the two names must be followed by exactly the five 32-bit fields
	fixed = parser->cur;
	if (fixed + SOA_FIXED_LEN != rdata_end)
		return -2;

	rec = (struct __dns_record_entry *)malloc(sizeof *rec + sizeof *soa);
	if (rec == NULL)
		return -1;

	// the SOA structure lives right behind the entry
	soa = (struct dns_record_soa *)(rec + 1);
	rec->record.rdata = (void *)soa;

	soa->serial = __dns_parser_uint32(fixed);
	soa->refresh = __dns_parser_uint32(fixed + 4);
	soa->retry = __dns_parser_uint32(fixed + 8);
	soa->expire = __dns_parser_uint32(fixed + 12);
	soa->minimum = __dns_parser_uint32(fixed + 16);

	soa->mname = strdup(mname);
	soa->rname = strdup(rname);
	if (soa->mname == NULL || soa->rname == NULL)
	{
		free(soa->mname);
		free(soa->rname);
		free(rec);
		return -1;
	}

	parser->cur = fixed + SOA_FIXED_LEN;
	*r = rec;

	return 0;
}

/*
 * SRV RDATA format, from RFC 2782:
 *                                 1  1  1  1  1  1
 *   0  1  2  3  4  5  6  7  8  9  0  1  2  3  4  5
 * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
 * |                    PRIORITY                   |
 * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
 * |                    WEIGHT                     |
 * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
 * |                     PORT                      |
 * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
 * /                    TARGET                     /
 * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
 *
 * PRIORITY: A 16 bit unsigned integer in network byte order.
 * WEIGHT: A 16 bit unsigned integer in network byte order.
 * PORT: A 16 bit unsigned integer in network byte order.
* TARGET: */ static int __dns_parser_parse_srv(struct __dns_record_entry **r, uint16_t rdlength, dns_parser_t *parser) { const char *rcdend; const char **cur; struct __dns_record_entry *entry; struct dns_record_srv *srv; size_t entry_size; char target[DNS_NAMES_MAX + 2]; uint16_t priority; uint16_t weight; uint16_t port; int ret; cur = &(parser->cur); rcdend = *cur + rdlength; if (*cur + 6 > rcdend) return -2; priority = __dns_parser_uint16(*cur); weight = __dns_parser_uint16(*cur + 2); port = __dns_parser_uint16(*cur + 4); *cur += 6; ret = __dns_parser_parse_host(target, parser); if (ret < 0) return ret; if (*cur != rcdend) return -2; entry_size = sizeof (struct __dns_record_entry) + sizeof (struct dns_record_srv); entry = (struct __dns_record_entry *)malloc(entry_size); if (!entry) return -1; entry->record.rdata = (void *)(entry + 1); srv = (struct dns_record_srv *)(entry->record.rdata); srv->priority = priority; srv->weight = weight; srv->port = port; srv->target = strdup(target); if (!srv->target) { free(entry); return -1; } *r = entry; return 0; } static int __dns_parser_parse_mx(struct __dns_record_entry **r, uint16_t rdlength, dns_parser_t *parser) { const char *rcdend; const char **cur; struct __dns_record_entry *entry; struct dns_record_mx *mx; size_t entry_size; char exchange[DNS_NAMES_MAX + 2]; int16_t preference; int ret; cur = &(parser->cur); rcdend = *cur + rdlength; if (*cur + 2 > rcdend) return -2; preference = __dns_parser_uint16(*cur); *cur += 2; ret = __dns_parser_parse_host(exchange, parser); if (ret < 0) return ret; if (*cur != rcdend) return -2; entry_size = sizeof (struct __dns_record_entry) + sizeof (struct dns_record_mx); entry = (struct __dns_record_entry *)malloc(entry_size); if (!entry) return -1; entry->record.rdata = (void *)(entry + 1); mx = (struct dns_record_mx *)(entry->record.rdata); mx->exchange = strdup(exchange); mx->preference = preference; if (!mx->exchange) { free(entry); return -1; } *r = entry; return 0; } static int 
__dns_parser_parse_others(struct __dns_record_entry **r, uint16_t rdlength, dns_parser_t *parser) { const char **cur; struct __dns_record_entry *entry; size_t entry_size; cur = &(parser->cur); entry_size = sizeof (struct __dns_record_entry) + rdlength; entry = (struct __dns_record_entry *)malloc(entry_size); if (!entry) return -1; entry->record.rdata = (void *)(entry + 1); memcpy(entry->record.rdata, *cur, rdlength); *cur += rdlength; *r = entry; return 0; } /* * RR format, from RFC 1035: * 1 1 1 1 1 1 * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ * | | * / NAME / * / / * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ * | TYPE | * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ * | CLASS | * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ * | TTL | * | | * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ * | RDLENGTH | * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ * / RDATA / * / / * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ * */ static int __dns_parser_parse_record(int idx, dns_parser_t *parser) { uint16_t i; uint16_t type; uint16_t rclass; uint32_t ttl; uint16_t rdlength; uint16_t count; const char *msgend; const char **cur; int ret; struct __dns_record_entry *entry; char host[DNS_NAMES_MAX + 2]; struct list_head *list; switch (idx) { case 0: count = parser->header.ancount; list = &parser->answer_list; break; case 1: count = parser->header.nscount; list = &parser->authority_list; break; case 2: count = parser->header.arcount; list = &parser->additional_list; break; default: return -2; } msgend = (const char *)parser->msgbuf + parser->msgsize; cur = &(parser->cur); for (i = 0; i < count; i++) { ret = __dns_parser_parse_host(host, parser); if (ret < 0) return ret; // TYPE(2) + CLASS(2) + TTL(4) + RDLENGTH(2) = 10 if (*cur + 10 > msgend) return -2; type = __dns_parser_uint16(*cur); rclass = __dns_parser_uint16(*cur + 2); ttl = __dns_parser_uint32(*cur + 4); rdlength = __dns_parser_uint16(*cur + 8); *cur 
+= 10; if (*cur + rdlength > msgend) return -2; entry = NULL; switch (type) { case DNS_TYPE_A: ret = __dns_parser_parse_a(&entry, rdlength, parser); break; case DNS_TYPE_AAAA: ret = __dns_parser_parse_aaaa(&entry, rdlength, parser); break; case DNS_TYPE_NS: case DNS_TYPE_CNAME: case DNS_TYPE_PTR: ret = __dns_parser_parse_names(&entry, rdlength, parser); break; case DNS_TYPE_SOA: ret = __dns_parser_parse_soa(&entry, rdlength, parser); break; case DNS_TYPE_SRV: ret = __dns_parser_parse_srv(&entry, rdlength, parser); break; case DNS_TYPE_MX: ret = __dns_parser_parse_mx(&entry, rdlength, parser); break; default: ret = __dns_parser_parse_others(&entry, rdlength, parser); } if (ret < 0) return ret; entry->record.name = strdup(host); if (!entry->record.name) { __dns_parser_free_record(entry); return -1; } entry->record.type = type; entry->record.rclass = rclass; entry->record.ttl = ttl; entry->record.rdlength = rdlength; list_add_tail(&entry->entry_list, list); } return 0; } /* * Question format, from RFC 1035: * 1 1 1 1 1 1 * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ * | | * / QNAME / * / / * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ * | QTYPE | * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ * | QCLASS | * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ * * The query name is encoded as a series of labels, each represented * as a one-byte length (maximum 63) followed by the text of the * label. The list is terminated by a label of length zero (which can * be thought of as the root domain). 
*/ static int __dns_parser_parse_question(dns_parser_t *parser) { uint16_t qtype; uint16_t qclass; const char *msgend; const char **cur; int ret; char host[DNS_NAMES_MAX + 2]; msgend = (const char *)parser->msgbuf + parser->msgsize; cur = &(parser->cur); // question count != 1 is an error if (parser->header.qdcount != 1) return -2; // parse qname ret = __dns_parser_parse_host(host, parser); if (ret < 0) return ret; // parse qtype and qclass if (*cur + 4 > msgend) return -2; qtype = __dns_parser_uint16(*cur); qclass = __dns_parser_uint16(*cur + 2); *cur += 4; if (parser->question.qname) free(parser->question.qname); parser->question.qname = strdup(host); if (parser->question.qname == NULL) return -1; parser->question.qtype = qtype; parser->question.qclass = qclass; return 0; } void dns_parser_init(dns_parser_t *parser) { parser->msgbuf = NULL; parser->msgbase = NULL; parser->cur = NULL; parser->msgsize = 0; parser->bufsize = 0; parser->complete = 0; parser->single_packet = 0; memset(&parser->header, 0, sizeof (struct dns_header)); memset(&parser->question, 0, sizeof (struct dns_question)); INIT_LIST_HEAD(&parser->answer_list); INIT_LIST_HEAD(&parser->authority_list); INIT_LIST_HEAD(&parser->additional_list); } int dns_parser_set_question(const char *name, uint16_t qtype, uint16_t qclass, dns_parser_t *parser) { int ret; ret = dns_parser_set_question_name(name, parser); if (ret < 0) return ret; parser->question.qtype = qtype; parser->question.qclass = qclass; parser->header.qdcount = 1; return 0; } int dns_parser_set_question_name(const char *name, dns_parser_t *parser) { char *newname; size_t len; len = strlen(name); newname = (char *)malloc(len + 1); if (!newname) return -1; memcpy(newname, name, len + 1); // Remove trailing dot, except name is "." 
if (len > 1 && newname[len - 1] == '.') newname[len - 1] = '\0'; if (parser->question.qname) free(parser->question.qname); parser->question.qname = newname; return 0; } void dns_parser_set_id(uint16_t id, dns_parser_t *parser) { parser->header.id = id; } int dns_parser_parse_all(dns_parser_t *parser) { struct dns_header *h; int ret; int i; parser->complete = 1; parser->cur = (const char *)parser->msgbase; h = &parser->header; if (parser->msgsize < sizeof (struct dns_header)) return -2; memcpy(h, parser->msgbase, sizeof (struct dns_header)); h->id = ntohs(h->id); h->qdcount = ntohs(h->qdcount); h->ancount = ntohs(h->ancount); h->nscount = ntohs(h->nscount); h->arcount = ntohs(h->arcount); parser->cur += sizeof (struct dns_header); ret = __dns_parser_parse_question(parser); if (ret < 0) return ret; for (i = 0; i < 3; i++) { ret = __dns_parser_parse_record(i, parser); if (ret < 0) return ret; } return 0; } int dns_parser_append_message(const void *buf, size_t *n, dns_parser_t *parser) { int ret; size_t total; size_t new_size; size_t msgsize_bak; void *new_buf; if (parser->complete) { *n = 0; return 1; } if (!parser->single_packet) { msgsize_bak = parser->msgsize; if (parser->msgsize + *n > parser->bufsize) { new_size = MAX(DNS_MSGBASE_INIT_SIZE, 2 * parser->bufsize); while (new_size < parser->msgsize + *n) new_size *= 2; new_buf = realloc(parser->msgbuf, new_size); if (!new_buf) return -1; parser->msgbuf = new_buf; parser->bufsize = new_size; } memcpy((char*)parser->msgbuf + parser->msgsize, buf, *n); parser->msgsize += *n; if (parser->msgsize < 2) return 0; total = __dns_parser_uint16((char*)parser->msgbuf); if (parser->msgsize < total + 2) return 0; *n = total + 2 - msgsize_bak; parser->msgsize = total + 2; parser->msgbase = (char*)parser->msgbuf + 2; } else { parser->msgbuf = malloc(*n); memcpy(parser->msgbuf, buf, *n); parser->msgbase = parser->msgbuf; parser->msgsize = *n; parser->bufsize = *n; } ret = dns_parser_parse_all(parser); if (ret < 0) return ret; return 
1; } void dns_parser_deinit(dns_parser_t *parser) { free(parser->msgbuf); free(parser->question.qname); __dns_parser_free_record_list(&parser->answer_list); __dns_parser_free_record_list(&parser->authority_list); __dns_parser_free_record_list(&parser->additional_list); } int dns_record_cursor_next(struct dns_record **record, dns_record_cursor_t *cursor) { struct __dns_record_entry *e; if (cursor->next->next != cursor->head) { cursor->next = cursor->next->next; e = list_entry(cursor->next, struct __dns_record_entry, entry_list); *record = &e->record; return 0; } return 1; } int dns_record_cursor_find_cname(const char *name, const char **cname, dns_record_cursor_t *cursor) { struct __dns_record_entry *e; if (!name || !cname) return 1; cursor->next = cursor->head; while (cursor->next->next != cursor->head) { cursor->next = cursor->next->next; e = list_entry(cursor->next, struct __dns_record_entry, entry_list); if (e->record.type == DNS_TYPE_CNAME && strcasecmp(name, e->record.name) == 0) { *cname = (const char *)e->record.rdata; return 0; } } return 1; } int dns_add_raw_record(const char *name, uint16_t type, uint16_t rclass, uint32_t ttl, uint16_t rlen, const void *rdata, struct list_head *list) { struct __dns_record_entry *entry; size_t entry_size = sizeof (struct __dns_record_entry) + rlen; entry = (struct __dns_record_entry *)malloc(entry_size); if (!entry) return -1; entry->record.name = strdup(name); if (!entry->record.name) { free(entry); return -1; } entry->record.type = type; entry->record.rclass = rclass; entry->record.ttl = ttl; entry->record.rdlength = rlen; entry->record.rdata = (void *)(entry + 1); memcpy(entry->record.rdata, rdata, rlen); list_add_tail(&entry->entry_list, list); return 0; } int dns_add_str_record(const char *name, uint16_t type, uint16_t rclass, uint32_t ttl, const char *rdata, struct list_head *list) { size_t rlen = strlen(rdata); // record.rdlength has no meaning for parsed record types, ignore its // correctness, same for soa/srv/mx 
record return dns_add_raw_record(name, type, rclass, ttl, rlen+1, rdata, list); } int dns_add_soa_record(const char *name, uint16_t rclass, uint32_t ttl, const char *mname, const char *rname, uint32_t serial, int32_t refresh, int32_t retry, int32_t expire, uint32_t minimum, struct list_head *list) { struct __dns_record_entry *entry; struct dns_record_soa *soa; size_t entry_size; char *pname, *pmname, *prname; entry_size = sizeof (struct __dns_record_entry) + sizeof (struct dns_record_soa); entry = (struct __dns_record_entry *)malloc(entry_size); if (!entry) return -1; entry->record.rdata = (void *)(entry + 1); entry->record.rdlength = 0; soa = (struct dns_record_soa *)(entry->record.rdata); pname = strdup(name); pmname = strdup(mname); prname = strdup(rname); if (!pname || !pmname || !prname) { free(pname); free(pmname); free(prname); free(entry); return -1; } soa->mname = pmname; soa->rname = prname; soa->serial = serial; soa->refresh = refresh; soa->retry = retry; soa->expire = expire; soa->minimum = minimum; entry->record.name = pname; entry->record.type = DNS_TYPE_SOA; entry->record.rclass = rclass; entry->record.ttl = ttl; list_add_tail(&entry->entry_list, list); return 0; } int dns_add_srv_record(const char *name, uint16_t rclass, uint32_t ttl, uint16_t priority, uint16_t weight, uint16_t port, const char *target, struct list_head *list) { struct __dns_record_entry *entry; struct dns_record_srv *srv; size_t entry_size; char *pname, *ptarget; entry_size = sizeof (struct __dns_record_entry) + sizeof (struct dns_record_srv); entry = (struct __dns_record_entry *)malloc(entry_size); if (!entry) return -1; entry->record.rdata = (void *)(entry + 1); entry->record.rdlength = 0; srv = (struct dns_record_srv *)(entry->record.rdata); pname = strdup(name); ptarget = strdup(target); if (!pname || !ptarget) { free(pname); free(ptarget); free(entry); return -1; } srv->priority = priority; srv->weight = weight; srv->port = port; srv->target = ptarget; entry->record.name = 
pname; entry->record.type = DNS_TYPE_SRV; entry->record.rclass = rclass; entry->record.ttl = ttl; list_add_tail(&entry->entry_list, list); return 0; } int dns_add_mx_record(const char *name, uint16_t rclass, uint32_t ttl, int16_t preference, const char *exchange, struct list_head *list) { struct __dns_record_entry *entry; struct dns_record_mx *mx; size_t entry_size; char *pname, *pexchange; entry_size = sizeof (struct __dns_record_entry) + sizeof (struct dns_record_mx); entry = (struct __dns_record_entry *)malloc(entry_size); if (!entry) return -1; entry->record.rdata = (void *)(entry + 1); entry->record.rdlength = 0; mx = (struct dns_record_mx *)(entry->record.rdata); pname = strdup(name); pexchange = strdup(exchange); if (!pname || !pexchange) { free(pname); free(pexchange); free(entry); return -1; } mx->preference = preference; mx->exchange = pexchange; entry->record.name = pname; entry->record.type = DNS_TYPE_MX; entry->record.rclass = rclass; entry->record.ttl = ttl; list_add_tail(&entry->entry_list, list); return 0; } const char *dns_type2str(int type) { switch (type) { case DNS_TYPE_A: return "A"; case DNS_TYPE_NS: return "NS"; case DNS_TYPE_MD: return "MD"; case DNS_TYPE_MF: return "MF"; case DNS_TYPE_CNAME: return "CNAME"; case DNS_TYPE_SOA: return "SOA"; case DNS_TYPE_MB: return "MB"; case DNS_TYPE_MG: return "MG"; case DNS_TYPE_MR: return "MR"; case DNS_TYPE_NULL: return "NULL"; case DNS_TYPE_WKS: return "WKS"; case DNS_TYPE_PTR: return "PTR"; case DNS_TYPE_HINFO: return "HINFO"; case DNS_TYPE_MINFO: return "MINFO"; case DNS_TYPE_MX: return "MX"; case DNS_TYPE_AAAA: return "AAAA"; case DNS_TYPE_SRV: return "SRV"; case DNS_TYPE_TXT: return "TXT"; case DNS_TYPE_AXFR: return "AXFR"; case DNS_TYPE_MAILB: return "MAILB"; case DNS_TYPE_MAILA: return "MAILA"; case DNS_TYPE_ALL: return "ALL"; default: return "Unknown"; } } const char *dns_class2str(int dnsclass) { switch (dnsclass) { case DNS_CLASS_IN: return "IN"; case DNS_CLASS_CS: return "CS"; case 
DNS_CLASS_CH: return "CH"; case DNS_CLASS_HS: return "HS"; case DNS_CLASS_ALL: return "ALL"; default: return "Unknown"; } } const char *dns_opcode2str(int opcode) { switch (opcode) { case DNS_OPCODE_QUERY: return "QUERY"; case DNS_OPCODE_IQUERY: return "IQUERY"; case DNS_OPCODE_STATUS: return "STATUS"; default: return "Unknown"; } } const char *dns_rcode2str(int rcode) { switch (rcode) { case DNS_RCODE_NO_ERROR: return "NO_ERROR"; case DNS_RCODE_FORMAT_ERROR: return "FORMAT_ERROR"; case DNS_RCODE_SERVER_FAILURE: return "SERVER_FAILURE"; case DNS_RCODE_NAME_ERROR: return "NAME_ERROR"; case DNS_RCODE_NOT_IMPLEMENTED: return "NOT_IMPLEMENTED"; case DNS_RCODE_REFUSED: return "REFUSED"; default: return "Unknown"; } } workflow-0.11.8/src/protocol/dns_parser.h000066400000000000000000000133541476003635400204320ustar00rootroot00000000000000/* Copyright (c) 2021 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Author: Liu Kai (liukaidx@sogou-inc.com) */ #ifndef _DNS_PARSER_H_ #define _DNS_PARSER_H_ #include #include #include "list.h" enum { DNS_TYPE_A = 1, DNS_TYPE_NS, DNS_TYPE_MD, DNS_TYPE_MF, DNS_TYPE_CNAME, DNS_TYPE_SOA = 6, DNS_TYPE_MB, DNS_TYPE_MG, DNS_TYPE_MR, DNS_TYPE_NULL, DNS_TYPE_WKS = 11, DNS_TYPE_PTR, DNS_TYPE_HINFO, DNS_TYPE_MINFO, DNS_TYPE_MX, DNS_TYPE_TXT = 16, DNS_TYPE_AAAA = 28, DNS_TYPE_SRV = 33, DNS_TYPE_AXFR = 252, DNS_TYPE_MAILB = 253, DNS_TYPE_MAILA = 254, DNS_TYPE_ALL = 255 }; enum { DNS_CLASS_IN = 1, DNS_CLASS_CS, DNS_CLASS_CH, DNS_CLASS_HS, DNS_CLASS_ALL = 255 }; enum { DNS_OPCODE_QUERY = 0, DNS_OPCODE_IQUERY, DNS_OPCODE_STATUS, }; enum { DNS_RCODE_NO_ERROR = 0, DNS_RCODE_FORMAT_ERROR, DNS_RCODE_SERVER_FAILURE, DNS_RCODE_NAME_ERROR, DNS_RCODE_NOT_IMPLEMENTED, DNS_RCODE_REFUSED }; enum { DNS_ANSWER_SECTION = 1, DNS_AUTHORITY_SECTION = 2, DNS_ADDITIONAL_SECTION = 3, }; /** * dns_header_t is a struct to describe the header of a dns * request or response packet, but the byte order is not * transformed. 
*/ #pragma pack(1) struct dns_header { uint16_t id; #if __BYTE_ORDER == __LITTLE_ENDIAN uint8_t rd : 1; uint8_t tc : 1; uint8_t aa : 1; uint8_t opcode : 4; uint8_t qr : 1; uint8_t rcode : 4; uint8_t z : 3; uint8_t ra : 1; #elif __BYTE_ORDER == __BIG_ENDIAN uint8_t qr : 1; uint8_t opcode : 4; uint8_t aa : 1; uint8_t tc : 1; uint8_t rd : 1; uint8_t ra : 1; uint8_t z : 3; uint8_t rcode : 4; #else # error "unknown byte order" #endif uint16_t qdcount; uint16_t ancount; uint16_t nscount; uint16_t arcount; }; #pragma pack() struct dns_question { char *qname; uint16_t qtype; uint16_t qclass; }; struct dns_record_soa { char *mname; char *rname; uint32_t serial; int32_t refresh; int32_t retry; int32_t expire; uint32_t minimum; }; struct dns_record_srv { uint16_t priority; uint16_t weight; uint16_t port; char *target; }; struct dns_record_mx { int16_t preference; char *exchange; }; struct dns_record { char *name; uint16_t type; uint16_t rclass; uint32_t ttl; uint16_t rdlength; void *rdata; }; typedef struct __dns_parser { void *msgbuf; // Message with leading length (TCP) void *msgbase; // Message without leading length const char *cur; // Current parser position size_t msgsize; size_t bufsize; char complete; // Whether parse completed char single_packet; // Response without leading length When UDP struct dns_header header; struct dns_question question; struct list_head answer_list; struct list_head authority_list; struct list_head additional_list; } dns_parser_t; typedef struct __dns_record_cursor { const struct list_head *head; const struct list_head *next; } dns_record_cursor_t; #ifdef __cplusplus extern "C" { #endif void dns_parser_init(dns_parser_t *parser); void dns_parser_set_id(uint16_t id, dns_parser_t *parser); int dns_parser_set_question(const char *name, uint16_t qtype, uint16_t qclass, dns_parser_t *parser); int dns_parser_set_question_name(const char *name, dns_parser_t *parser); int dns_parser_parse_all(dns_parser_t *parser); int dns_parser_append_message(const 
void *buf, size_t *n, dns_parser_t *parser); void dns_parser_deinit(dns_parser_t *parser); int dns_record_cursor_next(struct dns_record **record, dns_record_cursor_t *cursor); int dns_record_cursor_find_cname(const char *name, const char **cname, dns_record_cursor_t *cursor); int dns_add_raw_record(const char *name, uint16_t type, uint16_t rclass, uint32_t ttl, uint16_t rlen, const void *rdata, struct list_head *list); int dns_add_str_record(const char *name, uint16_t type, uint16_t rclass, uint32_t ttl, const char *rdata, struct list_head *list); int dns_add_soa_record(const char *name, uint16_t rclass, uint32_t ttl, const char *mname, const char *rname, uint32_t serial, int32_t refresh, int32_t retry, int32_t expire, uint32_t minimum, struct list_head *list); int dns_add_srv_record(const char *name, uint16_t rclass, uint32_t ttl, uint16_t priority, uint16_t weight, uint16_t port, const char *target, struct list_head *list); int dns_add_mx_record(const char *name, uint16_t rclass, uint32_t ttl, int16_t preference, const char *exchange, struct list_head *list); const char *dns_type2str(int type); const char *dns_class2str(int dnsclass); const char *dns_opcode2str(int opcode); const char *dns_rcode2str(int rcode); #ifdef __cplusplus } #endif static inline void dns_answer_cursor_init(dns_record_cursor_t *cursor, const dns_parser_t *parser) { cursor->head = &parser->answer_list; cursor->next = cursor->head; } static inline void dns_authority_cursor_init(dns_record_cursor_t *cursor, const dns_parser_t *parser) { cursor->head = &parser->authority_list; cursor->next = cursor->head; } static inline void dns_additional_cursor_init(dns_record_cursor_t *cursor, const dns_parser_t *parser) { cursor->head = &parser->additional_list; cursor->next = cursor->head; } static inline void dns_record_cursor_deinit(dns_record_cursor_t *cursor) { } #endif workflow-0.11.8/src/protocol/http_parser.c000066400000000000000000000415541476003635400206230ustar00rootroot00000000000000/* 
Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Xie Han (xiehan@sogou-inc.com) */ #include #include #include "list.h" #include "http_parser.h" #define MIN(x, y) ((x) <= (y) ? (x) : (y)) #define MAX(x, y) ((x) >= (y) ? (x) : (y)) #define HTTP_START_LINE_MAX 8192 #define HTTP_HEADER_VALUE_MAX 8192 #define HTTP_CHUNK_LINE_MAX 1024 #define HTTP_TRAILER_LINE_MAX 8192 #define HTTP_MSGBUF_INIT_SIZE 2048 enum { HPS_START_LINE, HPS_HEADER_NAME, HPS_HEADER_VALUE, HPS_HEADER_COMPLETE }; enum { CPS_CHUNK_DATA, CPS_TRAILER_PART, CPS_CHUNK_COMPLETE }; struct __header_line { struct list_head list; int name_len; int value_len; char *buf; }; static int __add_message_header(const void *name, size_t name_len, const void *value, size_t value_len, http_parser_t *parser) { size_t size = sizeof (struct __header_line) + name_len + value_len + 4; struct __header_line *line; line = (struct __header_line *)malloc(size); if (line) { line->buf = (char *)(line + 1); memcpy(line->buf, name, name_len); line->buf[name_len] = ':'; line->buf[name_len + 1] = ' '; memcpy(line->buf + name_len + 2, value, value_len); line->buf[name_len + 2 + value_len] = '\r'; line->buf[name_len + 2 + value_len + 1] = '\n'; line->name_len = name_len; line->value_len = value_len; list_add_tail(&line->list, &parser->header_list); return 0; } return -1; } static int __set_message_header(const void *name, size_t name_len, const void *value, size_t value_len, http_parser_t *parser) { struct __header_line *line; struct 
list_head *pos; char *buf; list_for_each(pos, &parser->header_list) { line = list_entry(pos, struct __header_line, list); if (line->name_len == name_len && strncasecmp(line->buf, name, name_len) == 0) { if (value_len > line->value_len) { buf = (char *)malloc(name_len + value_len + 4); if (!buf) return -1; if (line->buf != (char *)(line + 1)) free(line->buf); line->buf = buf; memcpy(buf, name, name_len); buf[name_len] = ':'; buf[name_len + 1] = ' '; } memcpy(line->buf + name_len + 2, value, value_len); line->buf[name_len + 2 + value_len] = '\r'; line->buf[name_len + 2 + value_len + 1] = '\n'; line->value_len = value_len; return 0; } } return __add_message_header(name, name_len, value, value_len, parser); } static int __match_request_line(const char *method, const char *uri, const char *version, http_parser_t *parser) { if (strcmp(version, "HTTP/1.0") == 0 || strncmp(version, "HTTP/0", 6) == 0) parser->keep_alive = 0; method = strdup(method); if (method) { uri = strdup(uri); if (uri) { version = strdup(version); if (version) { free(parser->method); free(parser->uri); free(parser->version); parser->method = (char *)method; parser->uri = (char *)uri; parser->version = (char *)version; return 0; } free((char *)uri); } free((char *)method); } return -1; } static int __match_status_line(const char *version, const char *code, const char *phrase, http_parser_t *parser) { if (strcmp(version, "HTTP/1.0") == 0 || strncmp(version, "HTTP/0", 6) == 0) parser->keep_alive = 0; if (*code == '1' || strcmp(code, "204") == 0 || strcmp(code, "304") == 0) parser->transfer_length = 0; version = strdup(version); if (version) { code = strdup(code); if (code) { phrase = strdup(phrase); if (phrase) { free(parser->version); free(parser->code); free(parser->phrase); parser->version = (char *)version; parser->code = (char *)code; parser->phrase = (char *)phrase; return 0; } free((char *)code); } free((char *)version); } return -1; } static void __check_message_header(const char *name, size_t 
name_len, const char *value, size_t value_len, http_parser_t *parser) { switch (name_len) { case 6: if (strncasecmp(name, "Expect", 6) == 0) { if (value_len == 12 && strncasecmp(value, "100-continue", 12) == 0) parser->expect_continue = 1; } break; case 10: if (strncasecmp(name, "Connection", 10) == 0) { parser->has_connection = 1; if (value_len == 10 && strncasecmp(value, "Keep-Alive", 10) == 0) parser->keep_alive = 1; else if (value_len == 5 && strncasecmp(value, "close", 5) == 0) parser->keep_alive = 0; } else if (strncasecmp(name, "Keep-Alive", 10) == 0) parser->has_keep_alive = 1; break; case 14: if (strncasecmp(name, "Content-Length", 14) == 0) { parser->has_content_length = 1; if (*value >= '0' && *value <= '9' && value_len <= 15) { char buf[16]; memcpy(buf, value, value_len); buf[value_len] = '\0'; parser->content_length = atol(buf); } } break; case 17: if (strncasecmp(name, "Transfer-Encoding", 17) == 0) { if (value_len != 8 || strncasecmp(value, "identity", 8) != 0) parser->chunked = 1; else parser->chunked = 0; } break; } } static int __parse_start_line(const char *ptr, size_t len, http_parser_t *parser) { char start_line[HTTP_START_LINE_MAX]; size_t min = MIN(HTTP_START_LINE_MAX, len); char *p1, *p2, *p3; size_t i; int ret; if (len >= 2 && ptr[0] == '\r' && ptr[1] == '\n') { parser->header_offset += 2; return 1; } for (i = 0; i < min; i++) { start_line[i] = ptr[i]; if (start_line[i] == '\r') { if (i == len - 1) return 0; if (ptr[i + 1] != '\n') return -2; start_line[i] = '\0'; p1 = start_line; p2 = strchr(p1, ' '); if (p2) *p2++ = '\0'; else return -2; p3 = strchr(p2, ' '); if (p3) *p3++ = '\0'; else return -2; if (parser->is_resp) ret = __match_status_line(p1, p2, p3, parser); else ret = __match_request_line(p1, p2, p3, parser); if (ret < 0) return -1; parser->header_offset += i + 2; parser->header_state = HPS_HEADER_NAME; return 1; } if (start_line[i] == 0) return -2; } if (i == HTTP_START_LINE_MAX) return -2; return 0; } static int 
__parse_header_name(const char *ptr, size_t len, http_parser_t *parser) { size_t min = MIN(HTTP_HEADER_NAME_MAX, len); size_t i; if (len >= 2 && ptr[0] == '\r' && ptr[1] == '\n') { parser->header_offset += 2; parser->header_state = HPS_HEADER_COMPLETE; return 1; } for (i = 0; i < min; i++) { if (ptr[i] == ':') { parser->namebuf[i] = '\0'; parser->header_offset += i + 1; parser->header_state = HPS_HEADER_VALUE; return 1; } if ((signed char)ptr[i] <= 0) return -2; parser->namebuf[i] = ptr[i]; } if (i == HTTP_HEADER_NAME_MAX) return -2; return 0; } static int __parse_header_value(const char *ptr, size_t len, http_parser_t *parser) { char header_value[HTTP_HEADER_VALUE_MAX]; const char *end = ptr + len; const char *begin = ptr; size_t i = 0; while (1) { while (1) { if (ptr == end) return 0; if (*ptr == ' ' || *ptr == '\t') ptr++; else break; } while (1) { if (i == HTTP_HEADER_VALUE_MAX) return -2; header_value[i] = *ptr++; if (ptr == end) return 0; if (header_value[i] == '\r') break; if ((signed char)header_value[i] <= 0) return -2; i++; } if (*ptr == '\n') ptr++; else return -2; if (ptr == end) return 0; while (i > 0) { if (header_value[i - 1] == ' ' || header_value[i - 1] == '\t') i--; else break; } if (*ptr != ' ' && *ptr != '\t') break; ptr++; header_value[i++] = ' '; } header_value[i] = '\0'; if (http_parser_add_header(parser->namebuf, strlen(parser->namebuf), header_value, i, parser) < 0) return -1; parser->header_offset += ptr - begin; parser->header_state = HPS_HEADER_NAME; return 1; } static int __parse_message_header(const void *message, size_t size, http_parser_t *parser) { const char *ptr; size_t len; int ret; do { ptr = (const char *)message + parser->header_offset; len = size - parser->header_offset; if (parser->header_state == HPS_START_LINE) ret = __parse_start_line(ptr, len, parser); else if (parser->header_state == HPS_HEADER_VALUE) ret = __parse_header_value(ptr, len, parser); else /* if (parser->header_state == HPS_HEADER_NAME) */ { ret = 
__parse_header_name(ptr, len, parser); if (parser->header_state == HPS_HEADER_COMPLETE) return 1; } } while (ret > 0); return ret; } #define CHUNK_SIZE_MAX (2 * 1024 * 1024 * 1024U - HTTP_CHUNK_LINE_MAX - 4) static int __parse_chunk_data(const char *ptr, size_t len, http_parser_t *parser) { char chunk_line[HTTP_CHUNK_LINE_MAX]; size_t min = MIN(HTTP_CHUNK_LINE_MAX, len); long chunk_size; char *end; size_t i; for (i = 0; i < min; i++) { chunk_line[i] = ptr[i]; if (chunk_line[i] == '\r') { if (i == len - 1) return 0; if (ptr[i + 1] != '\n') return -2; chunk_line[i] = '\0'; chunk_size = strtol(chunk_line, &end, 16); if (end == chunk_line) return -2; if (chunk_size == 0) { chunk_size = i + 2; parser->chunk_state = CPS_TRAILER_PART; } else if ((unsigned long)chunk_size < CHUNK_SIZE_MAX) { chunk_size += i + 4; if (len < (size_t)chunk_size) return 0; } else return -2; parser->chunk_offset += chunk_size; return 1; } } if (i == HTTP_CHUNK_LINE_MAX) return -2; return 0; } static int __parse_trailer_part(const char *ptr, size_t len, http_parser_t *parser) { size_t min = MIN(HTTP_TRAILER_LINE_MAX, len); size_t i; for (i = 0; i < min; i++) { if (ptr[i] == '\r') { if (i == len - 1) return 0; if (ptr[i + 1] != '\n') return -2; parser->chunk_offset += i + 2; if (i == 0) parser->chunk_state = CPS_CHUNK_COMPLETE; return 1; } } if (i == HTTP_TRAILER_LINE_MAX) return -2; return 0; } static int __parse_chunk(const void *message, size_t size, http_parser_t *parser) { const char *ptr; size_t len; int ret; do { ptr = (const char *)message + parser->chunk_offset; len = size - parser->chunk_offset; if (parser->chunk_state == CPS_CHUNK_DATA) ret = __parse_chunk_data(ptr, len, parser); else /* if (parser->chunk_state == CPS_TRAILER_PART) */ { ret = __parse_trailer_part(ptr, len, parser); if (parser->chunk_state == CPS_CHUNK_COMPLETE) return 1; } } while (ret > 0); return ret; } void http_parser_init(int is_resp, http_parser_t *parser) { parser->header_state = HPS_START_LINE; 
parser->header_offset = 0; parser->transfer_length = (size_t)-1; parser->content_length = is_resp ? (size_t)-1 : 0; parser->version = NULL; parser->method = NULL; parser->uri = NULL; parser->code = NULL; parser->phrase = NULL; INIT_LIST_HEAD(&parser->header_list); parser->msgbuf = NULL; parser->msgsize = 0; parser->bufsize = 0; parser->has_connection = 0; parser->has_content_length = 0; parser->has_keep_alive = 0; parser->expect_continue = 0; parser->keep_alive = 1; parser->chunked = 0; parser->complete = 0; parser->is_resp = is_resp; } int http_parser_append_message(const void *buf, size_t *n, http_parser_t *parser) { int ret; if (parser->complete) { *n = 0; return 1; } if (parser->msgsize + *n + 1 > parser->bufsize) { size_t new_size = MAX(HTTP_MSGBUF_INIT_SIZE, 2 * parser->bufsize); void *new_base; while (new_size < parser->msgsize + *n + 1) new_size *= 2; new_base = realloc(parser->msgbuf, new_size); if (!new_base) return -1; parser->msgbuf = new_base; parser->bufsize = new_size; } memcpy((char *)parser->msgbuf + parser->msgsize, buf, *n); parser->msgsize += *n; if (parser->header_state != HPS_HEADER_COMPLETE) { ret = __parse_message_header(parser->msgbuf, parser->msgsize, parser); if (ret <= 0) return ret; if (parser->chunked) { parser->chunk_offset = parser->header_offset; parser->chunk_state = CPS_CHUNK_DATA; } else if (parser->transfer_length == (size_t)-1) parser->transfer_length = parser->content_length; } if (parser->transfer_length != (size_t)-1) { size_t total = parser->header_offset + parser->transfer_length; if (parser->msgsize >= total) { *n -= parser->msgsize - total; parser->msgsize = total; parser->complete = 1; return 1; } return 0; } if (!parser->chunked) return 0; if (parser->chunk_state != CPS_CHUNK_COMPLETE) { ret = __parse_chunk(parser->msgbuf, parser->msgsize, parser); if (ret <= 0) return ret; } *n -= parser->msgsize - parser->chunk_offset; parser->msgsize = parser->chunk_offset; parser->complete = 1; return 1; } int 
http_parser_header_complete(const http_parser_t *parser) { return parser->header_state == HPS_HEADER_COMPLETE; } int http_parser_get_body(const void **body, size_t *size, const http_parser_t *parser) { if (parser->complete && parser->header_state == HPS_HEADER_COMPLETE) { *body = (char *)parser->msgbuf + parser->header_offset; *size = parser->msgsize - parser->header_offset; ((char *)parser->msgbuf)[parser->msgsize] = '\0'; return 0; } return 1; } int http_parser_set_method(const char *method, http_parser_t *parser) { method = strdup(method); if (method) { free(parser->method); parser->method = (char *)method; return 0; } return -1; } int http_parser_set_uri(const char *uri, http_parser_t *parser) { uri = strdup(uri); if (uri) { free(parser->uri); parser->uri = (char *)uri; return 0; } return -1; } int http_parser_set_version(const char *version, http_parser_t *parser) { version = strdup(version); if (version) { free(parser->version); parser->version = (char *)version; return 0; } return -1; } int http_parser_set_code(const char *code, http_parser_t *parser) { code = strdup(code); if (code) { free(parser->code); parser->code = (char *)code; return 0; } return -1; } int http_parser_set_phrase(const char *phrase, http_parser_t *parser) { phrase = strdup(phrase); if (phrase) { free(parser->phrase); parser->phrase = (char *)phrase; return 0; } return -1; } int http_parser_add_header(const void *name, size_t name_len, const void *value, size_t value_len, http_parser_t *parser) { if (__add_message_header(name, name_len, value, value_len, parser) >= 0) { __check_message_header((const char *)name, name_len, (const char *)value, value_len, parser); return 0; } return -1; } int http_parser_set_header(const void *name, size_t name_len, const void *value, size_t value_len, http_parser_t *parser) { if (__set_message_header(name, name_len, value, value_len, parser) >= 0) { __check_message_header((const char *)name, name_len, (const char *)value, value_len, parser); return 0; } 
return -1; } void http_parser_deinit(http_parser_t *parser) { struct __header_line *line; struct list_head *pos, *tmp; list_for_each_safe(pos, tmp, &parser->header_list) { line = list_entry(pos, struct __header_line, list); list_del(pos); if (line->buf != (char *)(line + 1)) free(line->buf); free(line); } free(parser->version); free(parser->method); free(parser->uri); free(parser->code); free(parser->phrase); free(parser->msgbuf); } int http_header_cursor_next(const void **name, size_t *name_len, const void **value, size_t *value_len, http_header_cursor_t *cursor) { struct __header_line *line; if (cursor->next->next != cursor->head) { cursor->next = cursor->next->next; line = list_entry(cursor->next, struct __header_line, list); *name = line->buf; *name_len = line->name_len; *value = line->buf + line->name_len + 2; *value_len = line->value_len; return 0; } return 1; } int http_header_cursor_find(const void *name, size_t name_len, const void **value, size_t *value_len, http_header_cursor_t *cursor) { struct __header_line *line; while (cursor->next->next != cursor->head) { cursor->next = cursor->next->next; line = list_entry(cursor->next, struct __header_line, list); if (line->name_len == name_len) { if (strncasecmp(line->buf, name, name_len) == 0) { *value = line->buf + name_len + 2; *value_len = line->value_len; return 0; } } } return 1; } int http_header_cursor_erase(http_header_cursor_t *cursor) { struct __header_line *line; if (cursor->next != cursor->head) { line = list_entry(cursor->next, struct __header_line, list); cursor->next = cursor->next->prev; list_del(&line->list); if (line->buf != (char *)(line + 1)) free(line->buf); free(line); return 0; } return 1; } workflow-0.11.8/src/protocol/http_parser.h000066400000000000000000000104071476003635400206210ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Xie Han (xiehan@sogou-inc.com) */ #ifndef _HTTP_PARSER_H_ #define _HTTP_PARSER_H_ #include #include "list.h" #define HTTP_HEADER_NAME_MAX 64 typedef struct __http_parser { int header_state; int chunk_state; size_t header_offset; size_t chunk_offset; size_t content_length; size_t transfer_length; char *version; char *method; char *uri; char *code; char *phrase; struct list_head header_list; char namebuf[HTTP_HEADER_NAME_MAX]; void *msgbuf; size_t msgsize; size_t bufsize; char has_connection; char has_content_length; char has_keep_alive; char expect_continue; char keep_alive; char chunked; char complete; char is_resp; } http_parser_t; typedef struct __http_header_cursor { const struct list_head *head; const struct list_head *next; } http_header_cursor_t; #ifdef __cplusplus extern "C" { #endif void http_parser_init(int is_resp, http_parser_t *parser); int http_parser_append_message(const void *buf, size_t *n, http_parser_t *parser); int http_parser_get_body(const void **body, size_t *size, const http_parser_t *parser); int http_parser_header_complete(const http_parser_t *parser); int http_parser_set_method(const char *method, http_parser_t *parser); int http_parser_set_uri(const char *uri, http_parser_t *parser); int http_parser_set_version(const char *version, http_parser_t *parser); int http_parser_set_code(const char *code, http_parser_t *parser); int http_parser_set_phrase(const char *phrase, http_parser_t *parser); int http_parser_add_header(const void *name, size_t name_len, const void *value, size_t value_len, http_parser_t *parser); int http_parser_set_header(const void 
*name, size_t name_len, const void *value, size_t value_len, http_parser_t *parser); void http_parser_deinit(http_parser_t *parser); int http_header_cursor_next(const void **name, size_t *name_len, const void **value, size_t *value_len, http_header_cursor_t *cursor); int http_header_cursor_find(const void *name, size_t name_len, const void **value, size_t *value_len, http_header_cursor_t *cursor); int http_header_cursor_erase(http_header_cursor_t *cursor); #ifdef __cplusplus } #endif static inline const char *http_parser_get_method(const http_parser_t *parser) { return parser->method; } static inline const char *http_parser_get_uri(const http_parser_t *parser) { return parser->uri; } static inline const char *http_parser_get_version(const http_parser_t *parser) { return parser->version; } static inline const char *http_parser_get_code(const http_parser_t *parser) { return parser->code; } static inline const char *http_parser_get_phrase(const http_parser_t *parser) { return parser->phrase; } static inline int http_parser_chunked(const http_parser_t *parser) { return parser->chunked; } static inline int http_parser_keep_alive(const http_parser_t *parser) { return parser->keep_alive; } static inline int http_parser_has_connection(const http_parser_t *parser) { return parser->has_connection; } static inline int http_parser_has_content_length(const http_parser_t *parser) { return parser->has_content_length; } static inline int http_parser_has_keep_alive(const http_parser_t *parser) { return parser->has_keep_alive; } static inline void http_parser_close_message(http_parser_t *parser) { parser->complete = 1; } static inline void http_header_cursor_init(http_header_cursor_t *cursor, const http_parser_t *parser) { cursor->head = &parser->header_list; cursor->next = cursor->head; } static inline void http_header_cursor_rewind(http_header_cursor_t *cursor) { cursor->next = cursor->head; } static inline void http_header_cursor_deinit(http_header_cursor_t *cursor) { } #endif 
workflow-0.11.8/src/protocol/kafka_parser.c000066400000000000000000000701511476003635400207140ustar00rootroot00000000000000/* Copyright (c) 2020 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Authors: Wang Zhulei (wangzhulei@sogou-inc.com) */ #include #include #include #include #include #include #include #include #include "kafka_parser.h" static kafka_api_version_t kafka_api_version_queryable[] = { { Kafka_ApiVersions, 0, 0 } }; static kafka_api_version_t kafka_api_version_0_9_0[] = { { Kafka_Produce, 0, 1 }, { Kafka_Fetch, 0, 1 }, { Kafka_ListOffsets, 0, 0 }, { Kafka_Metadata, 0, 0 }, { Kafka_OffsetCommit, 0, 2 }, { Kafka_OffsetFetch, 0, 1 }, { Kafka_FindCoordinator, 0, 0 }, { Kafka_JoinGroup, 0, 0 }, { Kafka_Heartbeat, 0, 0 }, { Kafka_LeaveGroup, 0, 0 }, { Kafka_SyncGroup, 0, 0 }, { Kafka_DescribeGroups, 0, 0 }, { Kafka_ListGroups, 0, 0 } }; static kafka_api_version_t kafka_api_version_0_8_2[] = { { Kafka_Produce, 0, 0 }, { Kafka_Fetch, 0, 0 }, { Kafka_ListOffsets, 0, 0 }, { Kafka_Metadata, 0, 0 }, { Kafka_OffsetCommit, 0, 1 }, { Kafka_OffsetFetch, 0, 1 }, { Kafka_FindCoordinator, 0, 0 } }; static kafka_api_version_t kafka_api_version_0_8_1[] = { { Kafka_Produce, 0, 0 }, { Kafka_Fetch, 0, 0 }, { Kafka_ListOffsets, 0, 0 }, { Kafka_Metadata, 0, 0 }, { Kafka_OffsetCommit, 0, 1 }, { Kafka_OffsetFetch, 0, 0 } }; static kafka_api_version_t kafka_api_version_0_8_0[] = { { Kafka_Produce, 0, 0 }, { Kafka_Fetch, 0, 0 }, { Kafka_ListOffsets, 0, 0 }, { Kafka_Metadata, 0, 0 } }; static const struct 
kafka_feature_map { unsigned feature; kafka_api_version_t depends[Kafka_ApiNums]; } kafka_feature_map[] = { { .feature = KAFKA_FEATURE_MSGVER1, .depends = { { Kafka_Produce, 2, 2 }, { Kafka_Fetch, 2, 2 }, { Kafka_Unknown, 0, 0 }, }, }, { .feature = KAFKA_FEATURE_MSGVER2, .depends = { { Kafka_Produce, 3, 3 }, { Kafka_Fetch, 4, 4 }, { Kafka_Unknown, 0, 0 }, }, }, { .feature = KAFKA_FEATURE_APIVERSION, .depends = { { Kafka_ApiVersions, 0, 0 }, { Kafka_Unknown, 0, 0 }, }, }, { .feature = KAFKA_FEATURE_BROKER_GROUP_COORD, .depends = { { Kafka_FindCoordinator, 0, 0 }, { Kafka_Unknown, 0, 0 }, }, }, { .feature = KAFKA_FEATURE_BROKER_BALANCED_CONSUMER, .depends = { { Kafka_FindCoordinator, 0, 0 }, { Kafka_OffsetCommit, 1, 2 }, { Kafka_OffsetFetch, 1, 1 }, { Kafka_JoinGroup, 0, 0 }, { Kafka_SyncGroup, 0, 0 }, { Kafka_Heartbeat, 0, 0 }, { Kafka_LeaveGroup, 0, 0 }, { Kafka_Unknown, 0, 0 }, }, }, { .feature = KAFKA_FEATURE_THROTTLETIME, .depends = { { Kafka_Produce, 1, 2 }, { Kafka_Fetch, 1, 2 }, { Kafka_Unknown, 0, 0 }, }, }, { .feature = KAFKA_FEATURE_LZ4, .depends = { { Kafka_FindCoordinator, 0, 0 }, { Kafka_Unknown, 0, 0 }, }, }, { .feature = KAFKA_FEATURE_OFFSET_TIME, .depends = { { Kafka_ListOffsets, 1, 1 }, { Kafka_Unknown, 0, 0 }, } }, { .feature = KAFKA_FEATURE_ZSTD, .depends = { { Kafka_Produce, 7, 7 }, { Kafka_Fetch, 10, 10 }, { Kafka_Unknown, 0, 0 }, }, }, { .feature = KAFKA_FEATURE_SASL_GSSAPI, .depends = { { Kafka_JoinGroup, 0, 0}, { Kafka_Unknown, 0, 0 }, }, }, { .feature = KAFKA_FEATURE_SASL_HANDSHAKE, .depends = { { Kafka_SaslHandshake, 0, 0}, { Kafka_Unknown, 0, 0 }, }, }, { .feature = KAFKA_FEATURE_SASL_AUTH_REQ, .depends = { { Kafka_SaslHandshake, 1, 1}, { Kafka_SaslAuthenticate, 0, 0}, { Kafka_Unknown, 0, 0 }, }, }, { .feature = 0, }, }; static int kafka_get_legacy_api_version(const char *broker_version, kafka_api_version_t **api, size_t *api_cnt) { static const struct { const char *pfx; kafka_api_version_t *api; size_t api_cnt; } vermap[] = { { "0.9.0", 
kafka_api_version_0_9_0, sizeof(kafka_api_version_0_9_0) / sizeof(kafka_api_version_t) }, { "0.8.2", kafka_api_version_0_8_2, sizeof(kafka_api_version_0_8_2) / sizeof(kafka_api_version_t) }, { "0.8.1", kafka_api_version_0_8_1, sizeof(kafka_api_version_0_8_1) / sizeof(kafka_api_version_t) }, { "0.8.0", kafka_api_version_0_8_0, sizeof(kafka_api_version_0_8_0) / sizeof(kafka_api_version_t) }, { "0.7.", NULL, 0 }, { "0.6", NULL, 0 }, { "", kafka_api_version_queryable, 1 }, { NULL, NULL, 0 } }; int i; for (i = 0 ; vermap[i].pfx ; i++) { if (!strncmp(vermap[i].pfx, broker_version, strlen(vermap[i].pfx))) { if (!vermap[i].api) return -1; *api = vermap[i].api; *api_cnt = vermap[i].api_cnt; break; } } return 0; } int kafka_api_version_is_queryable(const char *broker_version, kafka_api_version_t **api, size_t *api_cnt) { return kafka_get_legacy_api_version(broker_version, api, api_cnt); } static int kafka_api_version_key_cmp(const void *_a, const void *_b) { const kafka_api_version_t *a = _a, *b = _b; if (a->api_key > b->api_key) return 1; else if (a->api_key == b->api_key) return 0; else return -1; } static int kafka_api_version_check(const kafka_api_version_t *apis, size_t api_cnt, const kafka_api_version_t *match) { const kafka_api_version_t *api; api = bsearch(match, apis, api_cnt, sizeof(*apis), kafka_api_version_key_cmp); if (!api) return 0; return match->min_ver <= api->max_ver && api->min_ver <= match->max_ver; } unsigned kafka_get_features(kafka_api_version_t *api, size_t api_cnt) { unsigned features = 0; int i, fails, r; const kafka_api_version_t *match; for (i = 0; kafka_feature_map[i].feature != 0; i++) { fails = 0; for (match = &kafka_feature_map[i].depends[0]; match->api_key != -1 ; match++) { r = kafka_api_version_check(api, api_cnt, match); fails += !r; } if (!fails) features |= kafka_feature_map[i].feature; } return features; } int kafka_broker_get_api_version(const kafka_api_t *api, int api_key, int min_ver, int max_ver) { kafka_api_version_t sk = { 
.api_key = api_key }; kafka_api_version_t *retp; retp = bsearch(&sk, api->api, api->elements, sizeof(*api->api), kafka_api_version_key_cmp); if (!retp) return -1; if (retp->max_ver < max_ver) { if (retp->max_ver < min_ver) return -1; else return retp->max_ver; } else if (retp->min_ver > min_ver) return -1; else return max_ver; } void kafka_parser_init(kafka_parser_t *parser) { parser->complete = 0; parser->message_size = 0; parser->msgbuf = NULL; parser->cur_size = 0; parser->hsize = 0; } void kafka_parser_deinit(kafka_parser_t *parser) { free(parser->msgbuf); } void kafka_config_init(kafka_config_t *conf) { conf->produce_timeout = 100; conf->produce_msg_max_bytes = 1000000; conf->produce_msgset_cnt = 10000; conf->produce_msgset_max_bytes = 1000000; conf->fetch_timeout = 100; conf->fetch_min_bytes = 1; conf->fetch_max_bytes = 50 * 1024 * 1024; conf->fetch_msg_max_bytes = 10 * 1024 * 1024; conf->offset_timestamp = KAFKA_TIMESTAMP_LATEST; conf->commit_timestamp = 0; conf->session_timeout = 10*1000; conf->rebalance_timeout = 10000; conf->retention_time_period = 20000; conf->produce_acks = -1; conf->allow_auto_topic_creation = 1; conf->api_version_request = 0; conf->api_version_timeout = 10000; conf->broker_version = NULL; conf->compress_type = Kafka_NoCompress; conf->compress_level = 0; conf->client_id = NULL; conf->check_crcs = 0; conf->offset_store = KAFKA_OFFSET_AUTO; conf->rack_id = NULL; conf->mechanisms = NULL; conf->username = NULL; conf->password = NULL; conf->recv = NULL; conf->client_new = NULL; } void kafka_config_deinit(kafka_config_t *conf) { free(conf->broker_version); free(conf->client_id); free(conf->rack_id); free(conf->mechanisms); free(conf->username); free(conf->password); } void kafka_partition_init(kafka_partition_t *partition) { partition->error = 0; partition->partition_index = -1; kafka_broker_init(&partition->leader); partition->replica_nodes = NULL; partition->replica_node_elements = 0; partition->isr_nodes = NULL; 
partition->isr_node_elements = 0; } void kafka_partition_deinit(kafka_partition_t *partition) { kafka_broker_deinit(&partition->leader); free(partition->replica_nodes); free(partition->isr_nodes); } void kafka_api_init(kafka_api_t *api) { api->features = 0; api->api = NULL; api->elements = 0; } void kafka_api_deinit(kafka_api_t *api) { free(api->api); } void kafka_broker_init(kafka_broker_t *broker) { broker->node_id = -1; broker->port = 0; broker->host = NULL; broker->rack = NULL; broker->error = 0; broker->status = KAFKA_BROKER_UNINIT; } void kafka_broker_deinit(kafka_broker_t *broker) { free(broker->host); free(broker->rack); } void kafka_meta_init(kafka_meta_t *meta) { meta->error = 0; meta->topic_name = NULL; meta->error_message = NULL; meta->is_internal = 0; meta->partitions = NULL; meta->partition_elements = 0; } void kafka_meta_deinit(kafka_meta_t *meta) { int i; free(meta->topic_name); free(meta->error_message); for (i = 0; i < meta->partition_elements; ++i) { kafka_partition_deinit(meta->partitions[i]); free(meta->partitions[i]); } free(meta->partitions); } void kafka_topic_partition_init(kafka_topic_partition_t *toppar) { toppar->error = 0; toppar->topic_name = NULL; toppar->partition = -1; toppar->preferred_read_replica = -1; toppar->offset = KAFKA_OFFSET_UNINIT; toppar->high_watermark = KAFKA_OFFSET_UNINIT; toppar->low_watermark = KAFKA_OFFSET_UNINIT; toppar->last_stable_offset = -1; toppar->log_start_offset = -1; toppar->offset_timestamp = KAFKA_TIMESTAMP_UNINIT; toppar->committed_metadata = NULL; INIT_LIST_HEAD(&toppar->record_list); } void kafka_topic_partition_deinit(kafka_topic_partition_t *toppar) { free(toppar->topic_name); free(toppar->committed_metadata); } void kafka_record_header_init(kafka_record_header_t *header) { header->key = NULL; header->key_len = 0; header->key_is_moved = 0; header->value = NULL; header->value_len = 0; header->value_is_moved = 0; } void kafka_record_header_deinit(kafka_record_header_t *header) { if 
(!header->key_is_moved) free(header->key); if (!header->value_is_moved) free(header->value); } void kafka_record_init(kafka_record_t *record) { record->key = NULL; record->key_len = 0; record->key_is_moved = 0; record->value = NULL; record->value_len = 0; record->value_is_moved = 0; record->timestamp = 0; record->offset = 0; INIT_LIST_HEAD(&record->header_list); record->status = 0; record->toppar = NULL; } void kafka_record_deinit(kafka_record_t *record) { struct list_head *tmp, *pos; kafka_record_header_t *header; if (!record->key_is_moved) free(record->key); if (!record->value_is_moved) free(record->value); list_for_each_safe(pos, tmp, &record->header_list) { header = list_entry(pos, kafka_record_header_t, list); list_del(pos); kafka_record_header_deinit(header); free(header); } } void kafka_member_init(kafka_member_t *member) { member->member_id = NULL; member->client_id = NULL; member->client_host = NULL; member->member_metadata = NULL; member->member_metadata_len = 0; } void kafka_member_deinit(kafka_member_t *member) { free(member->member_id); free(member->client_id); free(member->client_host); //do not need free! 
//free(member->member_metadata); } void kafka_cgroup_init(kafka_cgroup_t *cgroup) { INIT_LIST_HEAD(&cgroup->assigned_toppar_list); cgroup->error = 0; cgroup->error_msg = NULL; kafka_broker_init(&cgroup->coordinator); cgroup->leader_id = NULL; cgroup->member_id = NULL; cgroup->members = NULL; cgroup->member_elements = 0; cgroup->generation_id = -1; cgroup->group_name = NULL; cgroup->protocol_type = "consumer"; cgroup->protocol_name = NULL; INIT_LIST_HEAD(&cgroup->group_protocol_list); } void kafka_cgroup_deinit(kafka_cgroup_t *cgroup) { int i; free(cgroup->error_msg); kafka_broker_deinit(&cgroup->coordinator); free(cgroup->leader_id); free(cgroup->member_id); for (i = 0; i < cgroup->member_elements; ++i) { kafka_member_deinit(cgroup->members[i]); free(cgroup->members[i]); } free(cgroup->members); free(cgroup->protocol_name); } void kafka_block_init(kafka_block_t *block) { block->buf = NULL; block->len = 0; block->is_moved = 0; } void kafka_block_deinit(kafka_block_t *block) { if (!block->is_moved) free(block->buf); } int kafka_parser_append_message(const void *buf, size_t *size, kafka_parser_t *parser) { size_t s = *size; int totaln; if (parser->complete) { *size = 0; return 1; } if (parser->hsize + s < 4) { memcpy(parser->headbuf + parser->hsize, buf, s); parser->hsize += s; return 0; } else if (!parser->msgbuf) { memcpy(parser->headbuf + parser->hsize, buf, 4 - parser->hsize); buf = (const char *)buf + 4 - parser->hsize; s -= 4 - parser->hsize; parser->hsize = 4; memcpy(&totaln, parser->headbuf, 4); parser->message_size = ntohl(totaln); parser->msgbuf = malloc(parser->message_size); if (!parser->msgbuf) return -1; parser->cur_size = 0; } if (s > parser->message_size - parser->cur_size) { memcpy((char *)parser->msgbuf + parser->cur_size, buf, parser->message_size - parser->cur_size); parser->cur_size = parser->message_size; } else { memcpy((char *)parser->msgbuf + parser->cur_size, buf, s); parser->cur_size += s; } if (parser->cur_size < parser->message_size) 
return 0; *size -= parser->message_size - parser->cur_size; return 1; } int kafka_topic_partition_set_tp(const char *topic_name, int partition, kafka_topic_partition_t *toppar) { char *p = strdup(topic_name); if (!p) return -1; free(toppar->topic_name); toppar->topic_name = p; toppar->partition = partition; return 0; } int kafka_record_set_key(const void *key, size_t key_len, kafka_record_t *record) { void *k = malloc(key_len); if (!k) return -1; free(record->key); memcpy(k, key, key_len); record->key = k; record->key_len = key_len; return 0; } int kafka_record_set_value(const void *val, size_t val_len, kafka_record_t *record) { void *v = malloc(val_len); if (!v) return -1; free(record->value); memcpy(v, val, val_len); record->value = v; record->value_len = val_len; return 0; } int kafka_record_header_set_kv(const void *key, size_t key_len, const void *val, size_t val_len, kafka_record_header_t *header) { void *k = malloc(key_len); void *v = malloc(val_len); if (!k || !v) { free(k); free(v); return -1; } memcpy(k, key, key_len); memcpy(v, val, val_len); header->key = k; header->key_len = key_len; header->value = v; header->value_len = val_len; return 0; } int kafka_meta_set_topic(const char *topic, kafka_meta_t *meta) { char *t = strdup(topic); if (!t) return -1; free(meta->topic_name); meta->topic_name = t; return 0; } int kafka_cgroup_set_group(const char *group, kafka_cgroup_t *cgroup) { char *t = strdup(group); if (!t) return -1; free(cgroup->group_name); cgroup->group_name = t; return 0; } static int kafka_sasl_plain_recv(const char *buf, size_t len, void *conf, void *q) { return 0; } static int kafka_sasl_plain_client_new(void *p, kafka_sasl_t *sasl) { kafka_config_t *conf = (kafka_config_t *)p; size_t ulen = strlen(conf->username); size_t plen = strlen(conf->password); size_t blen = ulen + plen + 2; size_t off = 0; char *buf = (char *)malloc(blen); if (!buf) return -1; buf[off++] = '\0'; memcpy(buf + off, conf->username, ulen); off += ulen; buf[off++] = 
'\0'; memcpy(buf + off, conf->password, plen); free(sasl->buf); sasl->buf = buf; sasl->bsize = blen; return 0; } static int scram_get_attr(const struct iovec *inbuf, char attr, struct iovec *outbuf) { const char *td; size_t len; size_t of = 0; void *ptr; char ochar, nchar; for (of = 0; of < inbuf->iov_len;) { ptr = (char *)inbuf->iov_base + of; td = memchr(ptr, ',', inbuf->iov_len - of); if (td) len = (size_t)((char *)td - (char *)inbuf->iov_base - of); else len = inbuf->iov_len - of; ochar = *((char *)inbuf->iov_base + of); nchar = *((char *)inbuf->iov_base + of + 1); if (ochar == attr && inbuf->iov_len > of + 1 && nchar == '=') { outbuf->iov_base = (char *)ptr + 2; outbuf->iov_len = len - 2; return 0; } of += len + 1; } return -1; } static char *scram_base64_encode(const struct iovec *in) { char *ret; size_t ret_len, max_len; if (in->iov_len > INT_MAX) return NULL; max_len = (((in->iov_len + 2) / 3) * 4) + 1; ret = malloc(max_len); if (!ret) return NULL; ret_len = EVP_EncodeBlock((uint8_t *)ret, (uint8_t *)in->iov_base, (int)in->iov_len); if (ret_len >= max_len) { free(ret); return NULL; } ret[ret_len] = 0; return ret; } static int scram_base64_decode(const struct iovec *in, struct iovec *out) { size_t ret_len; if (in->iov_len % 4 != 0 || in->iov_len > INT_MAX) return -1; ret_len = ((in->iov_len / 4) * 3); out->iov_base = malloc(ret_len + 1); if (!out->iov_base) return -1; if (EVP_DecodeBlock((uint8_t*)out->iov_base, (uint8_t*)in->iov_base, (int)in->iov_len) == -1) { free(out->iov_base); out->iov_base = NULL; return -1; } if (in->iov_len > 1 && ((char *)(in->iov_base))[in->iov_len - 1] == '=') { if (in->iov_len > 2 && ((char *)(in->iov_base))[in->iov_len - 2] == '=') ret_len -= 2; else ret_len -= 1; } ((char *)(out->iov_base))[ret_len] = '\0'; out->iov_len = ret_len; return 0; } static int scram_hi(const EVP_MD *evp, int itcnt, const struct iovec *in, const struct iovec *salt, struct iovec *out) { unsigned int ressize = 0; unsigned char tempres[EVP_MAX_MD_SIZE]; 
unsigned char tempdest[EVP_MAX_MD_SIZE]; unsigned char *saltplus; int i, j; saltplus = alloca(salt->iov_len + 4); if (!saltplus) return -1; memcpy(saltplus, salt->iov_base, salt->iov_len); saltplus[salt->iov_len] = '\0'; saltplus[salt->iov_len + 1] = '\0'; saltplus[salt->iov_len + 2] = '\0'; saltplus[salt->iov_len + 3] = '\1'; if (!HMAC(evp, (const unsigned char *)in->iov_base, (int)in->iov_len, saltplus, salt->iov_len + 4, tempres, &ressize)) { return -1; } memcpy(out->iov_base, tempres, ressize); for (i = 1; i < itcnt; i++) { if (!HMAC(evp, (const unsigned char *)in->iov_base, (int)in->iov_len, tempres, ressize, tempdest, NULL)) { return -1; } for (j = 0; j < (int)ressize; j++) { ((char *)(out->iov_base))[j] ^= tempdest[j]; tempres[j] = tempdest[j]; } } out->iov_len = ressize; return 0; } static int scram_hmac(const EVP_MD *evp, const struct iovec *key, const struct iovec *str, struct iovec *out) { unsigned int outsize; if (!HMAC(evp, (const unsigned char *)key->iov_base, (int)key->iov_len, (const unsigned char *)str->iov_base, (int)str->iov_len, (unsigned char *)out->iov_base, &outsize)) { return -1; } out->iov_len = outsize; return 0; } static void scram_h(kafka_scram_t *scram, const struct iovec *str, struct iovec *out) { scram->scram_h((const unsigned char *)str->iov_base, str->iov_len, (unsigned char *)out->iov_base); out->iov_len = scram->scram_h_size; } static void scram_build_client_final_message_wo_proof( kafka_scram_t *scram, const struct iovec *snonce, struct iovec *out) { const char *attr_c = "biws"; out->iov_len = 9 + scram->cnonce.iov_len + snonce->iov_len; out->iov_base = malloc(out->iov_len + 1); if (out->iov_base) { snprintf((char *)out->iov_base, out->iov_len + 1, "c=%s,r=%.*s%.*s", attr_c, (int)scram->cnonce.iov_len, (char *)scram->cnonce.iov_base, (int)snonce->iov_len, (char *)snonce->iov_base); } } static int scram_build_client_final_message(kafka_scram_t *scram, int itcnt, const struct iovec *salt, const struct iovec *server_first_msg, const 
struct iovec *server_nonce, struct iovec *out, const kafka_config_t *conf) { char salted_pwd[EVP_MAX_MD_SIZE]; char client_key[EVP_MAX_MD_SIZE]; char server_key[EVP_MAX_MD_SIZE]; char stored_key[EVP_MAX_MD_SIZE]; char client_sign[EVP_MAX_MD_SIZE]; char server_sign[EVP_MAX_MD_SIZE]; char client_proof[EVP_MAX_MD_SIZE]; struct iovec password_iov = {conf->password, strlen(conf->password)}; struct iovec salted_pwd_iov = {salted_pwd, EVP_MAX_MD_SIZE}; struct iovec client_key_verbatim_iov = {"Client Key", 10}; struct iovec server_key_verbatim_iov = {"Server Key", 10}; struct iovec client_key_iov = {client_key, EVP_MAX_MD_SIZE}; struct iovec server_key_iov = {server_key, EVP_MAX_MD_SIZE}; struct iovec stored_key_iov = {stored_key, EVP_MAX_MD_SIZE}; struct iovec server_sign_iov = {server_sign, EVP_MAX_MD_SIZE}; struct iovec client_sign_iov = {client_sign, EVP_MAX_MD_SIZE}; struct iovec client_proof_iov = {client_proof, EVP_MAX_MD_SIZE}; struct iovec client_final_msg_wo_proof_iov; struct iovec auth_message_iov; char *server_sign_b64, *client_proof_b64 = NULL; int i; if (scram_hi((const EVP_MD *)scram->evp, itcnt, &password_iov, salt, &salted_pwd_iov) == -1) return -1; if (scram_hmac((const EVP_MD *)scram->evp, &salted_pwd_iov, &client_key_verbatim_iov, &client_key_iov) == -1) return -1; scram_h(scram, &client_key_iov, &stored_key_iov); scram_build_client_final_message_wo_proof(scram, server_nonce, &client_final_msg_wo_proof_iov); auth_message_iov.iov_len = scram->first_msg.iov_len + 1 + server_first_msg->iov_len + 1 + client_final_msg_wo_proof_iov.iov_len; auth_message_iov.iov_base = alloca(auth_message_iov.iov_len + 1); if (auth_message_iov.iov_base) { snprintf(auth_message_iov.iov_base, auth_message_iov.iov_len + 1, "%.*s,%.*s,%.*s", (int)scram->first_msg.iov_len, (char *)scram->first_msg.iov_base, (int)server_first_msg->iov_len, (char *)server_first_msg->iov_base, (int)client_final_msg_wo_proof_iov.iov_len, (char *)client_final_msg_wo_proof_iov.iov_base); if 
(scram_hmac((const EVP_MD *)scram->evp, &salted_pwd_iov, &server_key_verbatim_iov, &server_key_iov) == 0 && scram_hmac((const EVP_MD *)scram->evp, &server_key_iov, &auth_message_iov, &server_sign_iov) == 0) { server_sign_b64 = scram_base64_encode(&server_sign_iov); if (server_sign_b64 && scram_hmac((const EVP_MD *)scram->evp, &stored_key_iov, &auth_message_iov, &client_sign_iov) ==0 && client_key_iov.iov_len == client_sign_iov.iov_len) { scram->server_signature_b64.iov_base = server_sign_b64; scram->server_signature_b64.iov_len = strlen(server_sign_b64); for (i = 0 ; i < (int)client_key_iov.iov_len; i++) ((char *)(client_proof_iov.iov_base))[i] = ((char *)(client_key_iov.iov_base))[i] ^ ((char *)(client_sign_iov.iov_base))[i]; client_proof_iov.iov_len = client_key_iov.iov_len; client_proof_b64 = scram_base64_encode(&client_proof_iov); if (client_proof_b64) { out->iov_len = client_final_msg_wo_proof_iov.iov_len + 3 + strlen(client_proof_b64); out->iov_base = malloc(out->iov_len + 1); snprintf((char *)out->iov_base, out->iov_len + 1, "%.*s,p=%s", (int)client_final_msg_wo_proof_iov.iov_len, (char *)client_final_msg_wo_proof_iov.iov_base, client_proof_b64); } } } } free(client_proof_b64); free(client_final_msg_wo_proof_iov.iov_base); return 0; } static int scram_handle_server_first_message(const char *buf, size_t len, kafka_config_t *conf, kafka_sasl_t *sasl) { int itcnt; int ret = -1; const char *endptr; struct iovec out, salt, server_nonce; const struct iovec in = {(void *)buf, len}; if (scram_get_attr(&in, 'm', &out) == 0) return -1; if (scram_get_attr(&in, 'r', &server_nonce) != 0) return -1; if (server_nonce.iov_len <= sasl->scram.cnonce.iov_len || memcmp(server_nonce.iov_base, sasl->scram.cnonce.iov_base, sasl->scram.cnonce.iov_len) != 0) { return -1; } if (scram_get_attr(&in, 's', &out) != 0) return -1; if (scram_base64_decode(&out, &salt) != 0) return -1; if (scram_get_attr(&in, 'i', &out) == 0) { itcnt = (int)strtoul((const char *)out.iov_base, (char 
**)&endptr, 10); if ((const char *)out.iov_base != endptr && *endptr == '\0' && itcnt <= 1000000) { ret = scram_build_client_final_message(&sasl->scram, itcnt, &salt, &in, &server_nonce, &out, conf); if (ret == 0) { free(sasl->buf); sasl->buf = out.iov_base; sasl->bsize = out.iov_len; } } } free(salt.iov_base); return ret; } static int scram_handle_server_final_message(const char *buf, size_t len, kafka_config_t *conf, kafka_sasl_t *sasl) { struct iovec attr_v, attr_e; const struct iovec in = {(void *)buf, len}; if (scram_get_attr(&in, 'm', &attr_e) == 0) return -1; if (scram_get_attr(&in, 'v', &attr_v) == 0) { if (sasl->scram.server_signature_b64.iov_len == attr_v.iov_len && strncmp((const char *)sasl->scram.server_signature_b64.iov_base, (const char *)attr_v.iov_base, attr_v.iov_len) != 0) { return -1; } } return 0; } static int kafka_sasl_scram_recv(const char *buf, size_t len, void *p, void *q) { kafka_config_t *conf = (kafka_config_t *)p; kafka_sasl_t *sasl = (kafka_sasl_t *)q; int ret = -1; switch(sasl->scram.state) { case KAFKA_SASL_SCRAM_STATE_SERVER_FIRST_MESSAGE: ret = scram_handle_server_first_message(buf, len, conf, sasl); sasl->scram.state = KAFKA_SASL_SCRAM_STATE_CLIENT_FINAL_MESSAGE; break; case KAFKA_SASL_SCRAM_STATE_CLIENT_FINAL_MESSAGE: ret = scram_handle_server_final_message(buf, len, conf, sasl); sasl->scram.state = KAFKA_SASL_SCRAM_STATE_CLIENT_FINISHED; break; default: break; } return ret; } static int jitter(int low, int high) { return (low + (rand() % ((high - low) + 1))); } static int scram_generate_nonce(struct iovec *iov) { int i; char *ptr = (char *)malloc(33); if (!ptr) return -1; for (i = 0; i < 32; i++) ptr[i] = jitter(0x2d, 0x7e); ptr[32] = '\0'; iov->iov_base = ptr; iov->iov_len = 32; return 0; } static int kafka_sasl_scram_client_new(void *p, kafka_sasl_t *sasl) { kafka_config_t *conf = (kafka_config_t *)p; size_t ulen = strlen(conf->username); size_t tlen = strlen("n,,n=,r="); size_t olen = ulen + tlen + 32; void *ptr; if 
(sasl->scram.state != KAFKA_SASL_SCRAM_STATE_CLIENT_FIRST_MESSAGE) return -1; if (scram_generate_nonce(&sasl->scram.cnonce) != 0) return -1; ptr = malloc(olen + 1); if (!ptr) return -1; snprintf(ptr, olen + 1, "n,,n=%s,r=%.*s", conf->username, (int)sasl->scram.cnonce.iov_len, (char *)sasl->scram.cnonce.iov_base); sasl->buf = ptr; sasl->bsize = olen; sasl->scram.first_msg.iov_base = (char *)ptr + 3; sasl->scram.first_msg.iov_len = olen - 3; sasl->scram.state = KAFKA_SASL_SCRAM_STATE_SERVER_FIRST_MESSAGE; return 0; } int kafka_sasl_set_mechanisms(kafka_config_t *conf) { if (strcasecmp(conf->mechanisms, "plain") == 0) { conf->recv = kafka_sasl_plain_recv; conf->client_new = kafka_sasl_plain_client_new; return 0; } else if (strncasecmp(conf->mechanisms, "SCRAM", 5) == 0) { conf->recv = kafka_sasl_scram_recv; conf->client_new = kafka_sasl_scram_client_new; } return -1; } void kafka_sasl_init(kafka_sasl_t *sasl) { sasl->scram.evp = NULL; sasl->scram.scram_h = NULL; sasl->scram.scram_h_size = 0; sasl->scram.state = KAFKA_SASL_SCRAM_STATE_CLIENT_FIRST_MESSAGE; sasl->scram.cnonce.iov_base = NULL; sasl->scram.cnonce.iov_len = 0; sasl->scram.first_msg.iov_base = NULL; sasl->scram.first_msg.iov_len = 0; sasl->scram.server_signature_b64.iov_base = NULL; sasl->scram.server_signature_b64.iov_len = 0; sasl->buf = NULL; sasl->bsize = 0; sasl->status = 0; } void kafka_sasl_deinit(kafka_sasl_t *sasl) { free(sasl->scram.cnonce.iov_base); free(sasl->scram.server_signature_b64.iov_base); free(sasl->buf); } int kafka_sasl_set_username(const char *username, kafka_config_t *conf) { char *t = strdup(username); if (!t) return -1; free(conf->username); conf->username = t; return 0; } int kafka_sasl_set_password(const char *password, kafka_config_t *conf) { char *t = strdup(password); if (!t) return -1; free(conf->password); conf->password = t; return 0; } workflow-0.11.8/src/protocol/kafka_parser.h000066400000000000000000000304121476003635400207150ustar00rootroot00000000000000/* Copyright (c) 
2020 Sogou, Inc.

  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

  Unless required by applicable law or agreed to in writing, software
  distributed under the License is distributed on an "AS IS" BASIS,
  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  See the License for the specific language governing permissions and
  limitations under the License.

  Authors: Wang Zhulei (wangzhulei@sogou-inc.com)
*/

#ifndef _KAFKA_PARSER_H_
#define _KAFKA_PARSER_H_

/* NOTE(review): the header names of the next three directives are missing
 * in this copy of the file. */
#include
#include
#include
#include "list.h"

/* Error code constants. KAFKA_UNKNOWN_SERVER_ERROR is the only negative
 * value; all others are numbered explicitly from 1 to 82. */
enum
{
	KAFKA_UNKNOWN_SERVER_ERROR = -1,
	KAFKA_OFFSET_OUT_OF_RANGE = 1,
	KAFKA_CORRUPT_MESSAGE = 2,
	KAFKA_UNKNOWN_TOPIC_OR_PARTITION = 3,
	KAFKA_INVALID_FETCH_SIZE = 4,
	KAFKA_LEADER_NOT_AVAILABLE = 5,
	KAFKA_NOT_LEADER_FOR_PARTITION = 6,
	KAFKA_REQUEST_TIMED_OUT = 7,
	KAFKA_BROKER_NOT_AVAILABLE = 8,
	KAFKA_REPLICA_NOT_AVAILABLE = 9,
	KAFKA_MESSAGE_TOO_LARGE = 10,
	KAFKA_STALE_CONTROLLER_EPOCH = 11,
	KAFKA_OFFSET_METADATA_TOO_LARGE = 12,
	KAFKA_NETWORK_EXCEPTION = 13,
	KAFKA_COORDINATOR_LOAD_IN_PROGRESS = 14,
	KAFKA_COORDINATOR_NOT_AVAILABLE = 15,
	KAFKA_NOT_COORDINATOR = 16,
	KAFKA_INVALID_TOPIC_EXCEPTION = 17,
	KAFKA_RECORD_LIST_TOO_LARGE = 18,
	KAFKA_NOT_ENOUGH_REPLICAS = 19,
	KAFKA_NOT_ENOUGH_REPLICAS_AFTER_APPEND = 20,
	KAFKA_INVALID_REQUIRED_ACKS = 21,
	KAFKA_ILLEGAL_GENERATION = 22,
	KAFKA_INCONSISTENT_GROUP_PROTOCOL = 23,
	KAFKA_INVALID_GROUP_ID = 24,
	KAFKA_UNKNOWN_MEMBER_ID = 25,
	KAFKA_INVALID_SESSION_TIMEOUT = 26,
	KAFKA_REBALANCE_IN_PROGRESS = 27,
	KAFKA_INVALID_COMMIT_OFFSET_SIZE = 28,
	KAFKA_TOPIC_AUTHORIZATION_FAILED = 29,
	KAFKA_GROUP_AUTHORIZATION_FAILED = 30,
	KAFKA_CLUSTER_AUTHORIZATION_FAILED = 31,
	KAFKA_INVALID_TIMESTAMP = 32,
	KAFKA_UNSUPPORTED_SASL_MECHANISM = 33,
	KAFKA_ILLEGAL_SASL_STATE = 34,
	KAFKA_UNSUPPORTED_VERSION = 35,
	KAFKA_TOPIC_ALREADY_EXISTS = 36,
	KAFKA_INVALID_PARTITIONS = 37,
	KAFKA_INVALID_REPLICATION_FACTOR = 38,
	KAFKA_INVALID_REPLICA_ASSIGNMENT = 39,
	KAFKA_INVALID_CONFIG = 40,
	KAFKA_NOT_CONTROLLER = 41,
	KAFKA_INVALID_REQUEST = 42,
	KAFKA_UNSUPPORTED_FOR_MESSAGE_FORMAT = 43,
	KAFKA_POLICY_VIOLATION = 44,
	KAFKA_OUT_OF_ORDER_SEQUENCE_NUMBER = 45,
	KAFKA_DUPLICATE_SEQUENCE_NUMBER = 46,
	KAFKA_INVALID_PRODUCER_EPOCH = 47,
	KAFKA_INVALID_TXN_STATE = 48,
	KAFKA_INVALID_PRODUCER_ID_MAPPING = 49,
	KAFKA_INVALID_TRANSACTION_TIMEOUT = 50,
	KAFKA_CONCURRENT_TRANSACTIONS = 51,
	KAFKA_TRANSACTION_COORDINATOR_FENCED = 52,
	KAFKA_TRANSACTIONAL_ID_AUTHORIZATION_FAILED = 53,
	KAFKA_SECURITY_DISABLED = 54,
	KAFKA_OPERATION_NOT_ATTEMPTED = 55,
	KAFKA_KAFKA_STORAGE_ERROR = 56,
	KAFKA_LOG_DIR_NOT_FOUND = 57,
	KAFKA_SASL_AUTHENTICATION_FAILED = 58,
	KAFKA_UNKNOWN_PRODUCER_ID = 59,
	KAFKA_REASSIGNMENT_IN_PROGRESS = 60,
	KAFKA_DELEGATION_TOKEN_AUTH_DISABLED = 61,
	KAFKA_DELEGATION_TOKEN_NOT_FOUND = 62,
	KAFKA_DELEGATION_TOKEN_OWNER_MISMATCH = 63,
	KAFKA_DELEGATION_TOKEN_REQUEST_NOT_ALLOWED = 64,
	KAFKA_DELEGATION_TOKEN_AUTHORIZATION_FAILED = 65,
	KAFKA_DELEGATION_TOKEN_EXPIRED = 66,
	KAFKA_INVALID_PRINCIPAL_TYPE = 67,
	KAFKA_NON_EMPTY_GROUP = 68,
	KAFKA_GROUP_ID_NOT_FOUND = 69,
	KAFKA_FETCH_SESSION_ID_NOT_FOUND = 70,
	KAFKA_INVALID_FETCH_SESSION_EPOCH = 71,
	KAFKA_LISTENER_NOT_FOUND = 72,
	KAFKA_TOPIC_DELETION_DISABLED = 73,
	KAFKA_FENCED_LEADER_EPOCH = 74,
	KAFKA_UNKNOWN_LEADER_EPOCH = 75,
	KAFKA_UNSUPPORTED_COMPRESSION_TYPE = 76,
	KAFKA_STALE_BROKER_EPOCH = 77,
	KAFKA_OFFSET_NOT_AVAILABLE = 78,
	KAFKA_MEMBER_ID_REQUIRED = 79,
	KAFKA_PREFERRED_LEADER_NOT_AVAILABLE = 80,
	KAFKA_GROUP_MAX_SIZE_REACHED = 81,
	KAFKA_FENCED_INSTANCE_ID = 82,
};

/* API key constants, numbered 0..44. Kafka_ApiNums follows the last key
 * and therefore evaluates to 45, the number of keys. */
enum
{
	Kafka_Unknown = -1,
	Kafka_Produce = 0,
	Kafka_Fetch = 1,
	Kafka_ListOffsets = 2,
	Kafka_Metadata = 3,
	Kafka_LeaderAndIsr = 4,
	Kafka_StopReplica = 5,
	Kafka_UpdateMetadata = 6,
	Kafka_ControlledShutdown = 7,
	Kafka_OffsetCommit = 8,
	Kafka_OffsetFetch = 9,
	Kafka_FindCoordinator = 10,
	Kafka_JoinGroup = 11,
	Kafka_Heartbeat = 12,
	Kafka_LeaveGroup = 13,
	Kafka_SyncGroup = 14,
	Kafka_DescribeGroups = 15,
	Kafka_ListGroups = 16,
	Kafka_SaslHandshake = 17,
	Kafka_ApiVersions = 18,
	Kafka_CreateTopics = 19,
	Kafka_DeleteTopics = 20,
	Kafka_DeleteRecords = 21,
	Kafka_InitProducerId = 22,
	Kafka_OffsetForLeaderEpoch = 23,
	Kafka_AddPartitionsToTxn = 24,
	Kafka_AddOffsetsToTxn = 25,
	Kafka_EndTxn = 26,
	Kafka_WriteTxnMarkers = 27,
	Kafka_TxnOffsetCommit = 28,
	Kafka_DescribeAcls = 29,
	Kafka_CreateAcls = 30,
	Kafka_DeleteAcls = 31,
	Kafka_DescribeConfigs = 32,
	Kafka_AlterConfigs = 33,
	Kafka_AlterReplicaLogDirs = 34,
	Kafka_DescribeLogDirs = 35,
	Kafka_SaslAuthenticate = 36,
	Kafka_CreatePartitions = 37,
	Kafka_CreateDelegationToken = 38,
	Kafka_RenewDelegationToken = 39,
	Kafka_ExpireDelegationToken = 40,
	Kafka_DescribeDelegationToken = 41,
	Kafka_DeleteGroups = 42,
	Kafka_ElectPreferredLeaders = 43,
	Kafka_IncrementalAlterConfigs = 44,
	Kafka_ApiNums,
};

/* Compression type constants (implicit values 0..4). */
enum
{
	Kafka_NoCompress,
	Kafka_Gzip,
	Kafka_Snappy,
	Kafka_Lz4,
	Kafka_Zstd,
};

/* Feature flags, one bit each (bits 0..11), so they can be OR-ed together.
 * Presumably stored in kafka_api_t.features - confirm against users. */
enum
{
	KAFKA_FEATURE_APIVERSION = 1<<0,
	KAFKA_FEATURE_BROKER_BALANCED_CONSUMER = 1<<1,
	KAFKA_FEATURE_THROTTLETIME = 1<<2,
	KAFKA_FEATURE_BROKER_GROUP_COORD = 1<<3,
	KAFKA_FEATURE_LZ4 = 1<<4,
	KAFKA_FEATURE_OFFSET_TIME = 1<<5,
	KAFKA_FEATURE_MSGVER2 = 1<<6,
	KAFKA_FEATURE_MSGVER1 = 1<<7,
	KAFKA_FEATURE_ZSTD = 1<<8,
	KAFKA_FEATURE_SASL_GSSAPI = 1<<9,
	KAFKA_FEATURE_SASL_HANDSHAKE = 1<<10,
	KAFKA_FEATURE_SASL_AUTH_REQ = 1<<11,
};

/* Offset storage modes (implicit values 0 and 1). */
enum
{
	KAFKA_OFFSET_AUTO,
	KAFKA_OFFSET_ASSIGN,
};

/* Broker initialisation status (implicit values 0..2). */
enum
{
	KAFKA_BROKER_UNINIT,
	KAFKA_BROKER_DOING,
	KAFKA_BROKER_INITED,
};

/* Special timestamp values. */
enum
{
	KAFKA_TIMESTAMP_EARLIEST = -2,
	KAFKA_TIMESTAMP_LATEST = -1,
	KAFKA_TIMESTAMP_UNINIT = 0,
};

/* Special offset values. */
enum
{
	KAFKA_OFFSET_UNINIT = -2,
	KAFKA_OFFSET_OVERFLOW = -1,
};

/* Supported version range [min_ver, max_ver] for one API key. */
typedef struct __kafka_api_version
{
	short api_key;
	short min_ver;
	short max_ver;
} kafka_api_version_t;

/* Feature mask plus an array of `elements` API version entries. */
typedef struct __kafka_api_t
{
	unsigned features;
	kafka_api_version_t *api;
	int elements;
} kafka_api_t;

/* Incremental message parser state; see kafka_parser_append_message(). */
typedef struct __kafka_parser
{
	int complete;
	size_t message_size;
	void *msgbuf;
	size_t cur_size;
	char headbuf[4];
	size_t hsize;
} kafka_parser_t;

/* Progress of a SCRAM exchange. kafka_sasl_init() starts in
 * CLIENT_FIRST_MESSAGE. */
enum __kafka_scram_state
{
	KAFKA_SASL_SCRAM_STATE_CLIENT_FIRST_MESSAGE,
	KAFKA_SASL_SCRAM_STATE_SERVER_FIRST_MESSAGE,
	KAFKA_SASL_SCRAM_STATE_CLIENT_FINAL_MESSAGE,
	KAFKA_SASL_SCRAM_STATE_CLIENT_FINISHED,
};

/* Per-connection SCRAM state: digest (evp), hash function and its output
 * size, exchange state, client nonce, first message and the expected
 * base64 server signature. */
typedef struct __kafka_scram
{
	const void *evp;
	unsigned char *(*scram_h)(const unsigned char *d, size_t n,
							  unsigned char *md);
	size_t scram_h_size;
	enum __kafka_scram_state state;
	struct iovec cnonce;
	struct iovec first_msg;
	struct iovec server_signature_b64;
} kafka_scram_t;

/* SASL state: SCRAM sub-state plus a message buffer (buf, bsize). */
typedef struct __kafka_sasl
{
	kafka_scram_t scram;
	char *buf;
	size_t bsize;
	int status;
} kafka_sasl_t;

/* Client configuration. client_new and recv are the SASL handlers
 * selected by kafka_sasl_set_mechanisms(). */
typedef struct __kafka_config
{
	int produce_timeout;
	int produce_msg_max_bytes;
	int produce_msgset_cnt;
	int produce_msgset_max_bytes;
	int fetch_timeout;
	int fetch_min_bytes;
	int fetch_max_bytes;
	int fetch_msg_max_bytes;
	long long offset_timestamp;
	long long commit_timestamp;
	int session_timeout;
	int rebalance_timeout;
	long long retention_time_period;
	int produce_acks;
	int allow_auto_topic_creation;
	int api_version_request;
	int api_version_timeout;
	char *broker_version;
	int compress_type;
	int compress_level;
	char *client_id;
	int check_crcs;
	int offset_store;
	char *rack_id;
	char *mechanisms;
	char *username;
	char *password;
	int (*client_new)(void *conf, kafka_sasl_t *sasl);
	int (*recv)(const char *buf, size_t len, void *conf, void *sasl);
} kafka_config_t;

/* One broker: identity, address, last error and status. */
typedef struct __kafka_broker
{
	int node_id;
	int port;
	char *host;
	char *rack;
	short error;
	int status;
} kafka_broker_t;

/* One partition: leader broker plus replica and ISR node id arrays, each
 * with its element count. */
typedef struct __kafka_partition
{
	short error;
	int partition_index;
	kafka_broker_t leader;
	int *replica_nodes;
	int replica_node_elements;
	int *isr_nodes;
	int isr_node_elements;
} kafka_partition_t;

/* Metadata of one topic: name, error and an array of partition pointers. */
typedef struct __kafka_meta
{
	short error;
	char *topic_name;
	char *error_message;
	signed char is_internal;
	kafka_partition_t **partitions;
	int partition_elements;
} kafka_meta_t;

/* A (topic, partition) pair with its offsets and a list of records. */
typedef struct __kafka_topic_partition
{
	short error;
	char *topic_name;
	int partition;
	int preferred_read_replica;
	long long offset;
	long long high_watermark;
	long long low_watermark;
	long long last_stable_offset;
	long long log_start_offset;
	long long offset_timestamp;
	char *committed_metadata;
	struct list_head record_list;
} kafka_topic_partition_t;

/* One record header (key/value pair), linked through `list`. */
typedef struct __kafka_record_header
{
	struct list_head list;
	void *key;
	size_t key_len;
	int key_is_moved;
	void *value;
	size_t value_len;
	int value_is_moved;
} kafka_record_header_t;

/* One record: key, value, timestamp, offset, headers and owning toppar. */
typedef struct __kafka_record
{
	void *key;
	size_t key_len;
	int key_is_moved;
	void *value;
	size_t value_len;
	int value_is_moved;
	long long timestamp;
	long long offset;
	struct list_head header_list;
	short status;
	kafka_topic_partition_t *toppar;
} kafka_record_t;

/* One consumer group member. */
typedef struct __kafka_memeber
{
	char *member_id;
	char *client_id;
	char *client_host;
	void *member_metadata;
	size_t member_metadata_len;
	struct list_head toppar_list;
	struct list_head assigned_toppar_list;
} kafka_member_t;

/* Partition assignor callback type. */
typedef int (*kafka_assignor_t)(kafka_member_t **members, int member_elements,
								void *meta_topic);

/* A named group protocol and its assignor, linked through `list`. */
typedef struct __kafka_group_protocol
{
	struct list_head list;
	char *protocol_name;
	kafka_assignor_t assignor;
} kafka_group_protocol_t;

/* Consumer group state. */
typedef struct __kafka_cgroup
{
	struct list_head assigned_toppar_list;
	short error;
	char *error_msg;
	kafka_broker_t coordinator;
	char *leader_id;
	char *member_id;
	kafka_member_t **members;
	int member_elements;
	int generation_id;
	char *group_name;
	char *protocol_type;
	char *protocol_name;
	struct list_head group_protocol_list;
} kafka_cgroup_t;

/* A byte buffer with an is_moved flag. */
typedef struct __kafka_block
{
	void *buf;
	size_t len;
	int is_moved;
} kafka_block_t;

#ifdef __cplusplus
extern "C"
{
#endif

int kafka_parser_append_message(const void *buf, size_t *size,
								kafka_parser_t *parser);

/* init/deinit pairs for each of the types declared above */
void kafka_parser_init(kafka_parser_t *parser);
void kafka_parser_deinit(kafka_parser_t *parser);
void kafka_topic_partition_init(kafka_topic_partition_t *toppar);
void kafka_topic_partition_deinit(kafka_topic_partition_t *toppar);
void kafka_cgroup_init(kafka_cgroup_t *cgroup);
void kafka_cgroup_deinit(kafka_cgroup_t *cgroup);
void kafka_block_init(kafka_block_t *block);
void kafka_block_deinit(kafka_block_t *block);
void kafka_broker_init(kafka_broker_t *brock);
void kafka_broker_deinit(kafka_broker_t *broker);
void kafka_config_init(kafka_config_t *config);
void kafka_config_deinit(kafka_config_t *config);
void kafka_meta_init(kafka_meta_t *meta);
void kafka_meta_deinit(kafka_meta_t *meta);
void kafka_partition_init(kafka_partition_t *partition);
void kafka_partition_deinit(kafka_partition_t *partition);
void kafka_member_init(kafka_member_t *member);
void kafka_member_deinit(kafka_member_t *member);
void kafka_record_init(kafka_record_t *record);
void kafka_record_deinit(kafka_record_t *record);
void kafka_record_header_init(kafka_record_header_t *header);
void kafka_record_header_deinit(kafka_record_header_t *header);
void kafka_api_init(kafka_api_t *api);
void kafka_api_deinit(kafka_api_t *api);
void kafka_sasl_init(kafka_sasl_t *sasl);
void kafka_sasl_deinit(kafka_sasl_t *sasl);

/* setters */
int kafka_topic_partition_set_tp(const char *topic_name, int partition,
								 kafka_topic_partition_t *toppar);
int kafka_record_set_key(const void *key, size_t key_len,
						 kafka_record_t *record);
int kafka_record_set_value(const void *val, size_t val_len,
						   kafka_record_t *record);
int kafka_record_header_set_kv(const void *key, size_t key_len,
							   const void *val, size_t val_len,
							   kafka_record_header_t *header);
int kafka_meta_set_topic(const char *topic_name, kafka_meta_t *meta);
int kafka_cgroup_set_group(const char *group_name, kafka_cgroup_t *cgroup);

/* API version / feature queries */
int kafka_broker_get_api_version(const kafka_api_t *broker, int api_key,
								 int min_ver, int max_ver);
unsigned kafka_get_features(kafka_api_version_t *api, size_t api_cnt);
int kafka_api_version_is_queryable(const char *broker_version,
								   kafka_api_version_t **api,
								   size_t *api_cnt);

/* SASL configuration */
int kafka_sasl_set_mechanisms(kafka_config_t *conf);
int kafka_sasl_set_username(const char *username, kafka_config_t *conf);
int
kafka_sasl_set_password(const char *passwd, kafka_config_t *conf); #ifdef __cplusplus } #endif #endif workflow-0.11.8/src/protocol/mysql_byteorder.c000066400000000000000000000031661476003635400215110ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Authors: Li Yingxin (liyingxin@sogou-inc.com) */ #include "mysql_byteorder.h" int decode_length_safe(unsigned long long *res, const unsigned char **pos, const unsigned char *end) { const unsigned char *p = *pos; if (p >= end) return 0; switch (*p) { default: *res = *p; *pos = p + 1; break; case 251: *res = (~0ULL); *pos = p + 1; break; case 252: if (p + 2 > end) return 0; *res = uint2korr(p + 1); *pos = p + 3; break; case 253: if (p + 3 > end) return 0; *res = uint3korr(p + 1); *pos = p + 4; break; case 254: if (p + 8 > end) return 0; *res = uint8korr(p + 1); *pos = p + 9; break; case 255: return -1; } return 1; } int decode_string(const unsigned char **str, unsigned long long *len, const unsigned char **pos, const unsigned char *end) { unsigned long long length; if (decode_length_safe(&length, pos, end) <= 0) return 0; if (length == (~0ULL)) length = 0; if (*pos + length > end) return 0; *len = length; *str = *pos; *pos = *pos + length; return 1; } workflow-0.11.8/src/protocol/mysql_byteorder.h000066400000000000000000000243311476003635400215130ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Authors: Li Yingxin (liyingxin@sogou-inc.com) */ #ifndef _MYSQL_BYTEORDER_H_ #define _MYSQL_BYTEORDER_H_ #include #include #include #if __BYTE_ORDER == __LITTLE_ENDIAN static inline int16_t sint2korr(const unsigned char *A) { int16_t ret; memcpy(&ret, A, sizeof(ret)); return ret; } static inline int32_t sint4korr(const unsigned char *A) { int32_t ret; memcpy(&ret, A, sizeof(ret)); return ret; } static inline uint16_t uint2korr(const unsigned char *A) { uint16_t ret; memcpy(&ret, A, sizeof(ret)); return ret; } static inline uint32_t uint4korr(const unsigned char *A) { uint32_t ret; memcpy(&ret, A, sizeof(ret)); return ret; } static inline uint64_t uint8korr(const unsigned char *A) { uint64_t ret; memcpy(&ret, A, sizeof(ret)); return ret; } static inline int64_t sint8korr(const unsigned char *A) { int64_t ret; memcpy(&ret, A, sizeof(ret)); return ret; } static inline void int2store(unsigned char *T, uint16_t A) { memcpy(T, &A, sizeof(A)); } static inline void int4store(unsigned char *T, uint32_t A) { memcpy(T, &A, sizeof(A)); } static inline void int7store(unsigned char *T, uint64_t A) { memcpy(T, &A, 7); } static inline void int8store(unsigned char *T, uint64_t A) { memcpy(T, &A, sizeof(A)); } static inline void float4get(float *V, const unsigned char *M) { memcpy(V, (M), sizeof(float)); } static inline void float4store(unsigned char *V, float M) { memcpy(V, (&M), sizeof(float)); } static inline void float8get(double *V, const unsigned char *M) { memcpy(V, 
M, sizeof(double)); } static inline void float8store(unsigned char *V, double M) { memcpy(V, &M, sizeof(double)); } static inline void floatget(float *V, const unsigned char *M) { float4get(V, M); } static inline void floatstore(unsigned char *V, float M) { float4store(V, M); } static inline void doublestore(unsigned char *T, double V) { memcpy(T, &V, sizeof(double)); } static inline void doubleget(double *V, const unsigned char *M) { memcpy(V, M, sizeof(double)); } static inline void ushortget(uint16_t *V, const unsigned char *pM) { *V = uint2korr(pM); } static inline void shortget(int16_t *V, const unsigned char *pM) { *V = sint2korr(pM); } static inline void longget(int32_t *V, const unsigned char *pM) { *V = sint4korr(pM); } static inline void ulongget(uint32_t *V, const unsigned char *pM) { *V = uint4korr(pM); } static inline void shortstore(unsigned char *T, int16_t V) { int2store(T, V); } static inline void longstore(unsigned char *T, int32_t V) { int4store(T, V); } static inline void longlongget(int64_t *V, const unsigned char *M) { memcpy(V, (M), sizeof(uint64_t)); } static inline void longlongstore(unsigned char *T, int64_t V) { memcpy((T), &V, sizeof(uint64_t)); } static inline int32_t sint3korr(const unsigned char *A) { int32_t ret = 0; memcpy(&ret, A, 3); return ret; } static inline uint32_t uint3korr(const unsigned char *A) { uint32_t ret = 0; memcpy(&ret, A, 3); return ret; } static inline void int3store(unsigned char *T, uint32_t A) { memcpy(T, &A, 3); } #elif __BYTE_ORDER == __BIG_ENDIAN static inline int16_t sint2korr(const unsigned char *A) { return (int16_t)(((int16_t)(A[0])) + ((int16_t)(A[1]) << 8)); } static inline int32_t sint4korr(const unsigned char *A) { return (int32_t)(((int32_t)(A[0])) + (((int32_t)(A[1]) << 8)) + (((int32_t)(A[2]) << 16)) + (((int32_t)(A[3]) << 24))); } static inline uint16_t uint2korr(const unsigned char *A) { return (uint16_t)(((uint16_t)(A[0])) + ((uint16_t)(A[1]) << 8)); } static inline uint32_t uint4korr(const 
unsigned char *A) { return (uint32_t)(((uint32_t)(A[0])) + (((uint32_t)(A[1])) << 8) + (((uint32_t)(A[2])) << 16) + (((uint32_t)(A[3])) << 24)); } static inline uint64_t uint8korr(const unsigned char *A) { return ((uint64_t)(((uint32_t)(A[0])) + (((uint32_t)(A[1])) << 8) + (((uint32_t)(A[2])) << 16) + (((uint32_t)(A[3])) << 24)) + (((uint64_t)(((uint32_t)(A[4])) + (((uint32_t)(A[5])) << 8) + (((uint32_t)(A[6])) << 16) + (((uint32_t)(A[7])) << 24))) << 32)); } static inline int64_t sint8korr(const unsigned char *A) { return (int64_t)uint8korr(A); } static inline void int2store(unsigned char *T, uint16_t A) { uint def_temp = A; *(T) = (unsigned char)(def_temp); *(T + 1) = (unsigned char)(def_temp >> 8); } static inline void int4store(unsigned char *T, uint32_t A) { *(T) = (unsigned char)(A); *(T + 1) = (unsigned char)(A >> 8); *(T + 2) = (unsigned char)(A >> 16); *(T + 3) = (unsigned char)(A >> 24); } static inline void int7store(unsigned char *T, uint64_t A) { *(T) = (unsigned char)(A); *(T + 1) = (unsigned char)(A >> 8); *(T + 2) = (unsigned char)(A >> 16); *(T + 3) = (unsigned char)(A >> 24); *(T + 4) = (unsigned char)(A >> 32); *(T + 5) = (unsigned char)(A >> 40); *(T + 6) = (unsigned char)(A >> 48); } static inline void int8store(unsigned char *T, uint64_t A) { uint def_temp = (uint)A, def_temp2 = (uint)(A >> 32); int4store(T, def_temp); int4store(T + 4, def_temp2); } static inline void float4store(unsigned char *T, float A) { *(T) = ((unsigned char *)&A)[3]; *((T) + 1) = (unsigned char)((unsigned char *)&A)[2]; *((T) + 2) = (unsigned char)((unsigned char *)&A)[1]; *((T) + 3) = (unsigned char)((unsigned char *)&A)[0]; } static inline void float4get(float *V, const unsigned char *M) { float def_temp; ((unsigned char *)&def_temp)[0] = (M)[3]; ((unsigned char *)&def_temp)[1] = (M)[2]; ((unsigned char *)&def_temp)[2] = (M)[1]; ((unsigned char *)&def_temp)[3] = (M)[0]; (*V) = def_temp; } static inline void float8store(unsigned char *T, double V) { *(T) = ((unsigned 
char *)&V)[7]; *((T) + 1) = (unsigned char)((unsigned char *)&V)[6]; *((T) + 2) = (unsigned char)((unsigned char *)&V)[5]; *((T) + 3) = (unsigned char)((unsigned char *)&V)[4]; *((T) + 4) = (unsigned char)((unsigned char *)&V)[3]; *((T) + 5) = (unsigned char)((unsigned char *)&V)[2]; *((T) + 6) = (unsigned char)((unsigned char *)&V)[1]; *((T) + 7) = (unsigned char)((unsigned char *)&V)[0]; } static inline void float8get(double *V, const unsigned char *M) { double def_temp; ((unsigned char *)&def_temp)[0] = (M)[7]; ((unsigned char *)&def_temp)[1] = (M)[6]; ((unsigned char *)&def_temp)[2] = (M)[5]; ((unsigned char *)&def_temp)[3] = (M)[4]; ((unsigned char *)&def_temp)[4] = (M)[3]; ((unsigned char *)&def_temp)[5] = (M)[2]; ((unsigned char *)&def_temp)[6] = (M)[1]; ((unsigned char *)&def_temp)[7] = (M)[0]; (*V) = def_temp; } static inline void ushortget(uint16_t *V, const unsigned char *pM) { *V = (uint16_t)(((uint16_t)((unsigned char)(pM)[1])) + ((uint16_t)((uint16_t)(pM)[0]) << 8)); } static inline void shortget(int16_t *V, const unsigned char *pM) { *V = (short)(((short)((unsigned char)(pM)[1])) + ((short)((short)(pM)[0]) << 8)); } static inline void longget(int32_t *V, const unsigned char *pM) { int32_t def_temp; ((unsigned char *)&def_temp)[0] = (pM)[0]; ((unsigned char *)&def_temp)[1] = (pM)[1]; ((unsigned char *)&def_temp)[2] = (pM)[2]; ((unsigned char *)&def_temp)[3] = (pM)[3]; (*V) = def_temp; } static inline void ulongget(uint32_t *V, const unsigned char *pM) { uint32_t def_temp; ((unsigned char *)&def_temp)[0] = (pM)[0]; ((unsigned char *)&def_temp)[1] = (pM)[1]; ((unsigned char *)&def_temp)[2] = (pM)[2]; ((unsigned char *)&def_temp)[3] = (pM)[3]; (*V) = def_temp; } static inline void shortstore(unsigned char *T, int16_t A) { uint def_temp = (uint)(A); *(((unsigned char *)T) + 1) = (unsigned char)(def_temp); *(((unsigned char *)T) + 0) = (unsigned char)(def_temp >> 8); } static inline void longstore(unsigned char *T, int32_t A) { *(((unsigned char *)T) + 3) 
= ((A)); *(((unsigned char *)T) + 2) = (((A) >> 8)); *(((unsigned char *)T) + 1) = (((A) >> 16)); *(((unsigned char *)T) + 0) = (((A) >> 24)); } static inline void floatget(float *V, const unsigned char *M) { memcpy(V, (M), sizeof(float)); } static inline void floatstore(unsigned char *T, float V) { memcpy((T), (&V), sizeof(float)); } static inline void doubleget(double *V, const unsigned char *M) { memcpy(V, (M), sizeof(double)); } static inline void doublestore(unsigned char *T, double V) { memcpy((T), &V, sizeof(double)); } static inline void longlongget(int64_t *V, const unsigned char *M) { memcpy(V, (M), sizeof(uint64_t)); } static inline void longlongstore(unsigned char *T, int64_t V) { memcpy((T), &V, sizeof(uint64_t)); } static inline int32_t sint3korr(const unsigned char *p) { return ((int32_t)(((p[2]) & 128) ? (((uint32_t)255L << 24) | (((uint32_t)p[2]) << 16) | (((uint32_t)p[1]) << 8) | ((uint32_t)p[0])) : (((uint32_t)p[2]) << 16) | (((uint32_t)p[1]) << 8) | ((uint32_t)p[0]))); } static inline uint32_t uint3korr(const unsigned char *p) { return (uint32_t)(((uint32_t)(p[0])) + (((uint32_t)(p[1])) << 8) + (((uint32_t)(p[2])) << 16)); } static inline void int3store(unsigned char *p, uint32_t x) { *(p) = (unsigned char)(x); *(p + 1) = (unsigned char)(x >> 8); *(p + 2) = (unsigned char)(x >> 16); } #else # error "unknown byte order" #endif // length of buffer needed to store this number [1, 3, 4, 9]. 
static inline unsigned int get_length_size(unsigned long long num) { if (num < (unsigned long long)252LL) return 1; if (num < (unsigned long long)65536LL) return 3; if (num < (unsigned long long)16777216LL) return 4; return 9; } #ifdef __cplusplus extern "C" { #endif // decode encoded length integer within *end, move pos forward int decode_length_safe(unsigned long long *res, const unsigned char **pos, const unsigned char *end); // decode encoded length string within *end, move pos forward int decode_string(const unsigned char **str, unsigned long long *len, const unsigned char **pos, const unsigned char *end); #ifdef __cplusplus } #endif #endif workflow-0.11.8/src/protocol/mysql_parser.c000066400000000000000000000306521476003635400210060ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Authors: Li Yingxin (liyingxin@sogou-inc.com) */ #include #include "mysql_types.h" #include "mysql_byteorder.h" #include "mysql_parser.h" static int parse_base_packet(const void *buf, size_t len, mysql_parser_t *parser); static int parse_error_packet(const void *buf, size_t len, mysql_parser_t *parser); static int parse_ok_packet(const void *buf, size_t len, mysql_parser_t *parser); static int parse_eof_packet(const void *buf, size_t len, mysql_parser_t *parser); static int parse_field_eof_packet(const void *buf, size_t len, mysql_parser_t *parser); static int parse_field_count(const void *buf, size_t len, mysql_parser_t *parser); static int parse_column_def_packet(const void *buf, size_t len, mysql_parser_t *parser); static int parse_local_inline(const void *buf, size_t len, mysql_parser_t *parser); static int parse_row_packet(const void *buf, size_t len, mysql_parser_t *parser); void mysql_parser_init(mysql_parser_t *parser) { parser->offset = 0; parser->cmd = MYSQL_COM_QUERY; parser->packet_type = MYSQL_PACKET_OTHER; parser->parse = parse_base_packet; parser->result_set_count = 0; INIT_LIST_HEAD(&parser->result_set_list); } void mysql_parser_deinit(mysql_parser_t *parser) { struct __mysql_result_set *result_set; struct list_head *pos, *tmp; int i; list_for_each_safe(pos, tmp, &parser->result_set_list) { result_set = list_entry(pos, struct __mysql_result_set, list); list_del(pos); if (result_set->field_count) { for (i = 0; i < result_set->field_count; i++) free(result_set->fields[i]); free(result_set->fields); } free(result_set); } } int mysql_parser_parse(const void *buf, size_t len, mysql_parser_t *parser) { // const char *end = (const char *)buf + len; int ret; do { ret = parser->parse(buf, len, parser); if (ret < 0) return ret; if (ret > 0 && parser->offset != len) return -2; } while (parser->offset < len); return ret; } void mysql_parser_get_net_state(const char **net_state_str, size_t *net_state_len, mysql_parser_t *parser) { *net_state_str = (const char 
*)parser->buf + parser->net_state_offset; *net_state_len = MYSQL_STATE_LENGTH; } void mysql_parser_get_err_msg(const char **err_msg_str, size_t *err_msg_len, mysql_parser_t *parser) { if (parser->err_msg_offset == (size_t)-1 && parser->err_msg_len == 0) { *err_msg_str = MYSQL_STATE_DEFAULT; *err_msg_len = MYSQL_STATE_LENGTH; } else { *err_msg_str = (const char *)parser->buf + parser->err_msg_offset; *err_msg_len = parser->err_msg_len; } } static int parse_base_packet(const void *buf, size_t len, mysql_parser_t *parser) { const unsigned char *p = (const unsigned char *)buf + parser->offset; switch (*p) { // OK PACKET case MYSQL_PACKET_HEADER_OK: parser->parse = parse_ok_packet; break; // ERR PACKET case MYSQL_PACKET_HEADER_ERROR: parser->parse = parse_error_packet; break; // EOF PACKET case MYSQL_PACKET_HEADER_EOF: parser->parse = parse_eof_packet; break; // LOCAL INFILE PACKET case MYSQL_PACKET_HEADER_NULL: // if (field_count == -1) parser->parse = parse_local_inline; break; default: parser->parse = parse_field_count; break; } return 0; } // 1:0xFF|2:err_no|1:#|5:server_state|0-512:err_msg static int parse_error_packet(const void *buf, size_t len, mysql_parser_t *parser) { const unsigned char *p = (const unsigned char *)buf + parser->offset; const unsigned char *buf_end = (const unsigned char *)buf + len; if (p + 9 > buf_end) return -2; parser->error = uint2korr(p + 1); p += 3; if (*p == '#') { p += 1; parser->net_state_offset = p - (const unsigned char *)buf; p += MYSQL_STATE_LENGTH; size_t msg_len = len - parser->offset - 9; parser->err_msg_offset = p - (const unsigned char *)buf; parser->err_msg_len = msg_len; } else { parser->err_msg_offset = (size_t)-1; parser->err_msg_len = 0; } parser->offset = len; parser->packet_type = MYSQL_PACKET_ERROR; parser->buf = buf; return 1; } // 1:0x00|1-9:affect_row|1-9:insert_id|2:server_status|2:warning_count|0-n:server_msg static int parse_ok_packet(const void *buf, size_t len, mysql_parser_t *parser) { const unsigned char *p 
= (const unsigned char *)buf + parser->offset; const unsigned char *buf_end = (const unsigned char *)buf + len; unsigned long long affected_rows, insert_id, info_len; const unsigned char *str; struct __mysql_result_set *result_set; unsigned int warning_count; int server_status; p += 1;// 0x00 if (decode_length_safe(&affected_rows, &p, buf_end) <= 0) return -2; if (decode_length_safe(&insert_id, &p, buf_end) <= 0) return -2; if (p + 4 > buf_end) return -2; server_status = uint2korr(p); p += 2; warning_count = uint2korr(p); p += 2; if (p != buf_end) { if (decode_string(&str, &info_len, &p, buf_end) == 0) return -2; if (p != buf_end) { if (server_status & MYSQL_SERVER_SESSION_STATE_CHANGED) { const unsigned char *tmp_str; unsigned long long tmp_len; if (decode_string(&tmp_str, &tmp_len, &p, buf_end) == 0) return -2; } else return -2; } } else { str = p; info_len = 0; } result_set = (struct __mysql_result_set *)malloc(sizeof(struct __mysql_result_set)); if (result_set == NULL) return -1; result_set->info_offset = str - (const unsigned char *)buf; result_set->info_len = info_len; result_set->affected_rows = (affected_rows == ~0ULL) ? 0 : affected_rows; result_set->insert_id = (insert_id == ~0ULL) ? 
0 : insert_id; result_set->server_status = server_status; result_set->warning_count = warning_count; result_set->type = MYSQL_PACKET_OK; result_set->field_count = 0; list_add_tail(&result_set->list, &parser->result_set_list); parser->current_result_set = result_set; parser->result_set_count++; parser->packet_type = MYSQL_PACKET_OK; parser->buf = buf; parser->offset = p - (const unsigned char *)buf; if (server_status & MYSQL_SERVER_MORE_RESULTS_EXIST) { parser->parse = parse_base_packet; return 0; } return 1; } // 1:0xfe|2:warnings|2:status_flag static int parse_eof_packet(const void *buf, size_t len, mysql_parser_t *parser) { const unsigned char *p = (const unsigned char *)buf + parser->offset; const unsigned char *buf_end = (const unsigned char *)buf + len; if (p + 5 > buf_end) return -2; parser->offset += 5; parser->packet_type = MYSQL_PACKET_EOF; parser->buf = buf; int status_flag = uint2korr(p + 3); if (status_flag & MYSQL_SERVER_MORE_RESULTS_EXIST) { parser->parse = parse_base_packet; return 0; } return 1; } static int parse_field_eof_packet(const void *buf, size_t len, mysql_parser_t *parser) { const unsigned char *p = (const unsigned char *)buf + parser->offset; const unsigned char *buf_end = (const unsigned char *)buf + len; if (p + 5 > buf_end) return -2; parser->offset += 5; parser->current_result_set->rows_begin_offset = parser->offset; parser->parse = parse_row_packet; return 0; } //raw file data static int parse_local_inline(const void *buf, size_t len, mysql_parser_t *parser) { parser->local_inline_offset = parser->offset; parser->local_inline_length = len - parser->offset; parser->offset = len; parser->packet_type = MYSQL_PACKET_LOCAL_INLINE; parser->buf = buf; return 1; } // for each field: // NULL as 0xfb, or a length-encoded-string static int parse_row_packet(const void *buf, size_t len, mysql_parser_t *parser) { const unsigned char *p = (const unsigned char *)buf + parser->offset; const unsigned char *buf_end = (const unsigned char *)buf + len; 
unsigned long long cell_len; const unsigned char *cell_data; size_t i; if (*p == MYSQL_PACKET_HEADER_ERROR) { parser->parse = parse_error_packet; return 0; } if (*p == MYSQL_PACKET_HEADER_EOF) { parser->parse = parse_eof_packet; parser->current_result_set->rows_end_offset = parser->offset; return 0; } for (i = 0; i < parser->current_result_set->field_count; i++) { if (*p == MYSQL_PACKET_HEADER_NULL) { p++; } else { if (decode_string(&cell_data, &cell_len, &p, buf_end) == 0) break; } } if (i != parser->current_result_set->field_count) return -2; parser->current_result_set->row_count++; parser->offset = p - (const unsigned char *)buf; return 0; } static int parse_field_count(const void *buf, size_t len, mysql_parser_t *parser) { const unsigned char *p = (const unsigned char *)buf + parser->offset; const unsigned char *buf_end = (const unsigned char *)buf + len; unsigned long long field_count; struct __mysql_result_set *result_set; if (decode_length_safe(&field_count, &p, buf_end) <= 0) return -2; field_count = (field_count == ~0ULL) ? 0 : field_count; if (field_count) { result_set = (struct __mysql_result_set *)malloc(sizeof (struct __mysql_result_set)); if (result_set == NULL) return -1; result_set->fields = (mysql_field_t **)calloc(field_count, sizeof (mysql_field_t *)); if (result_set->fields == NULL) { free(result_set); return -1; } result_set->field_count = field_count; result_set->row_count = 0; result_set->type = MYSQL_PACKET_GET_RESULT; list_add_tail(&result_set->list, &parser->result_set_list); parser->current_result_set = result_set; parser->current_field_count = 0; parser->result_set_count++; parser->packet_type = MYSQL_PACKET_GET_RESULT; parser->parse = parse_column_def_packet; parser->offset = p - (const unsigned char *)buf; } else { parser->parse = parse_ok_packet; } return 0; } // COLUMN DEFINATION PACKET. 
for one field: (after protocol 41) // str:catalog|str:db|str:table|str:org_table|str:name|str:org_name| // 2:charsetnr|4:length|1:type|2:flags|1:decimals|1:0x00|1:0x00|n:str(if COM_FIELD_LIST) static int parse_column_def_packet(const void *buf, size_t len, mysql_parser_t *parser) { const unsigned char *p = (const unsigned char *)buf + parser->offset; const unsigned char *buf_end = (const unsigned char *)buf + len; int flag = 0; const unsigned char *str; unsigned long long str_len; mysql_field_t *field = (mysql_field_t *)malloc(sizeof(mysql_field_t)); if (!field) return -1; do { if (decode_string(&str, &str_len, &p, buf_end) == 0) break; field->catalog_offset = str - (const unsigned char *)buf; field->catalog_length = str_len; if (decode_string(&str, &str_len, &p, buf_end) == 0) break; field->db_offset = str - (const unsigned char *)buf; field->db_length = str_len; if (decode_string(&str, &str_len, &p, buf_end) == 0) break; field->table_offset = str - (const unsigned char *)buf; field->table_length = str_len; if (decode_string(&str, &str_len, &p, buf_end) == 0) break; field->org_table_offset = str - (const unsigned char *)buf; field->org_table_length = str_len; if (decode_string(&str, &str_len, &p, buf_end) == 0) break; field->name_offset = str - (const unsigned char *)buf; field->name_length = str_len; if (decode_string(&str, &str_len, &p, buf_end) == 0) break; field->org_name_offset = str - (const unsigned char *)buf; field->org_name_length = str_len; // the rest needs at least 13 if (p + 13 > buf_end) break; p++; // length of the following fields (always 0x0c) field->charsetnr = uint2korr(p); field->length = uint4korr(p + 2); field->data_type = *(p + 6); field->flags = uint2korr(p + 7); field->decimals = (int)p[9]; p += 12; // if is COM_FIELD_LIST, the rest is a string // 0x03 for COM_QUERY if (parser->cmd == MYSQL_COM_FIELD_LIST) { if (decode_string(&str, &str_len, &p, buf_end) == 0) break; field->def_offset = str - (const unsigned char *)buf; field->def_length 
= str_len; } else { field->def_offset = (size_t)-1; field->def_length = 0; } flag = 1; } while (0); if (flag == 0) { free(field); return -2; } //parser->fields.emplace_back(std::move(field)); parser->current_result_set->fields[parser->current_field_count] = field; parser->offset = p - (const unsigned char *)buf; if (++parser->current_field_count == parser->current_result_set->field_count) parser->parse = parse_field_eof_packet; return 0; } workflow-0.11.8/src/protocol/mysql_parser.h000066400000000000000000000115111476003635400210040ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Authors: Li Yingxin (liyingxin@sogou-inc.com) */ #ifndef _MYSQL_PARSER_H_ #define _MYSQL_PARSER_H_ #include #include "list.h" // the first byte in response message // from 1 to 0xfa means result_set field or data_row // NULL is sent as 0xfb, will be treated as LOCAL_INLINE // MySQL MESSAGE STATUS enum { MYSQL_PACKET_HEADER_OK = 0, MYSQL_PACKET_HEADER_NULL = 251, //0xfb MYSQL_PACKET_HEADER_EOF = 254, //0xfe MYSQL_PACKET_HEADER_ERROR = 255, //0xff }; typedef struct __mysql_field { size_t name_offset; /* Name of column */ size_t org_name_offset; /* Original column name, if an alias */ size_t table_offset; /* Table of column if column was a field */ size_t org_table_offset; /* Org table name, if table was an alias */ size_t db_offset; /* Database for table */ size_t catalog_offset; /* Catalog for table */ size_t def_offset; /* Default value (set by mysql_list_fields) */ int length; /* Width of column (create length) */ int name_length; int org_name_length; int table_length; int org_table_length; int db_length; int catalog_length; int def_length; int flags; /* Div flags */ int decimals; /* Number of decimals in field */ int charsetnr; /* Character set */ int data_type; /* Type of field. 
See mysql_types.h for types */ // void *extension; } mysql_field_t; struct __mysql_result_set { struct list_head list; int type; int server_status; int field_count; int row_count; size_t rows_begin_offset; size_t rows_end_offset; mysql_field_t **fields; unsigned long long affected_rows; unsigned long long insert_id; int warning_count; size_t info_offset; int info_len; }; typedef struct __mysql_result_set_cursor { const struct list_head *head; const struct list_head *current; } mysql_result_set_cursor_t; typedef struct __mysql_parser { size_t offset; int cmd; int packet_type; int (*parse)(const void *, size_t, struct __mysql_parser *); size_t net_state_offset; // err packet server_state size_t err_msg_offset; // -1 for default int err_msg_len; // -1 for default size_t local_inline_offset; // local inline file name int local_inline_length; const void *buf; int error; int result_set_count; struct list_head result_set_list; struct __mysql_result_set *current_result_set; int current_field_count; } mysql_parser_t; #ifdef __cplusplus extern "C" { #endif void mysql_parser_init(mysql_parser_t *parser); void mysql_parser_deinit(mysql_parser_t *parser); void mysql_parser_get_info(const char **info_str, size_t *info_len, mysql_parser_t *parser); void mysql_parser_get_net_state(const char **net_state_str, size_t *net_state_len, mysql_parser_t *parser); void mysql_parser_get_err_msg(const char **err_msg_str, size_t *err_msg_len, mysql_parser_t *parser); // if append check get 0, don`t need to parse() // if append check get 1, parse and tell them if this is all the package // // ret: 1: this ResultSet is received finished // 0: this ResultSet is not recieved finished // -1: system error // -2: bad message error int mysql_parser_parse(const void *buf, size_t len, mysql_parser_t *parser); #ifdef __cplusplus } #endif static inline void mysql_parser_set_command(int cmd, mysql_parser_t *parser) { parser->cmd = cmd; } static inline void mysql_parser_get_local_inline(const char 
**local_inline_name, size_t *local_inline_len, mysql_parser_t *parser) { *local_inline_name = (const char *)parser->buf + parser->local_inline_offset; *local_inline_len = parser->local_inline_length; } static inline void mysql_result_set_cursor_init(mysql_result_set_cursor_t *cursor, mysql_parser_t *parser) { cursor->head = &parser->result_set_list; cursor->current = cursor->head; } static inline void mysql_result_set_cursor_rewind(mysql_result_set_cursor_t *cursor) { cursor->current = cursor->head; } static inline void mysql_result_set_cursor_deinit(mysql_result_set_cursor_t *cursor) { } static inline int mysql_result_set_cursor_next(struct __mysql_result_set **result_set, mysql_result_set_cursor_t *cursor) { if (cursor->current->next != cursor->head) { cursor->current = cursor->current->next; *result_set = list_entry(cursor->current, struct __mysql_result_set, list); return 0; } return 1; } #endif workflow-0.11.8/src/protocol/mysql_stream.c000066400000000000000000000046611476003635400210060ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Xie Han (xiehan@sogou-inc.com) */ #include #include #include "mysql_stream.h" #define MAX(x, y) ((x) >= (y) ? 
(x) : (y)) static int __mysql_stream_write_payload(const void *buf, size_t *n, mysql_stream_t *stream); static int __mysql_stream_write_head(const void *buf, size_t *n, mysql_stream_t *stream) { void *p = &stream->head[4 - stream->head_left]; if (*n < stream->head_left) { memcpy(p, buf, *n); stream->head_left -= *n; return 0; } memcpy(p, buf, stream->head_left); stream->payload_length = (stream->head[2] << 16) + (stream->head[1] << 8) + stream->head[0]; stream->payload_left = stream->payload_length; stream->sequence_id = stream->head[3]; if (stream->bufsize < stream->length + stream->payload_left) { size_t new_size = MAX(2048, 2 * stream->bufsize); void *new_base; while (new_size < stream->length + stream->payload_left) new_size *= 2; new_base = realloc(stream->buf, new_size); if (!new_base) return -1; stream->buf = new_base; stream->bufsize = new_size; } *n = stream->head_left; stream->write = __mysql_stream_write_payload; return 0; } static int __mysql_stream_write_payload(const void *buf, size_t *n, mysql_stream_t *stream) { char *p = (char *)stream->buf + stream->length; if (*n < stream->payload_left) { memcpy(p, buf, *n); stream->length += *n; stream->payload_left -= *n; return 0; } memcpy(p, buf, stream->payload_left); stream->length += stream->payload_left; *n = stream->payload_left; stream->head_left = 4; stream->write = __mysql_stream_write_head; return stream->payload_length != (1 << 24) - 1; } void mysql_stream_init(mysql_stream_t *stream) { stream->head_left = 4; stream->sequence_id = 0; stream->buf = NULL; stream->length = 0; stream->bufsize = 0; stream->write = __mysql_stream_write_head; } workflow-0.11.8/src/protocol/mysql_stream.h000066400000000000000000000030701476003635400210040ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Xie Han (xiehan@sogou-inc.com) */ #ifndef _MYSQL_STREAM_H_ #define _MYSQL_STREAM_H_ #include typedef struct __mysql_stream { unsigned char head[4]; unsigned char head_left; unsigned char sequence_id; int payload_length; int payload_left; void *buf; size_t length; size_t bufsize; int (*write)(const void *, size_t *, struct __mysql_stream *); } mysql_stream_t; #ifdef __cplusplus extern "C" { #endif void mysql_stream_init(mysql_stream_t *stream); #ifdef __cplusplus } #endif static inline int mysql_stream_write(const void *buf, size_t *n, mysql_stream_t *stream) { return stream->write(buf, n, stream); } static inline int mysql_stream_get_seq(mysql_stream_t *stream) { return stream->sequence_id; } static inline void mysql_stream_get_buf(const void **buf, size_t *length, mysql_stream_t *stream) { *buf = stream->buf; *length = stream->length; } static inline void mysql_stream_deinit(mysql_stream_t *stream) { free(stream->buf); } #endif workflow-0.11.8/src/protocol/mysql_types.h000066400000000000000000000102651476003635400206610ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. Authors: Li Yingxin (liyingxin@sogou-inc.com) Wu Jiaxu (wujiaxu@sogou-inc.com) */ #ifndef _MYSQL_TYPES_H_ #define _MYSQL_TYPES_H_ #define MYSQL_STATE_LENGTH 5 #define MYSQL_STATE_DEFAULT "HY000" #define MYSQL_SERVER_MORE_RESULTS_EXIST 0x0008 #define MYSQL_SERVER_SESSION_STATE_CHANGED 0x4000 enum { MYSQL_COM_SLEEP, MYSQL_COM_QUIT, MYSQL_COM_INIT_DB, MYSQL_COM_QUERY, MYSQL_COM_FIELD_LIST, MYSQL_COM_CREATE_DB, MYSQL_COM_DROP_DB, MYSQL_COM_REFRESH, MYSQL_COM_DEPRECATED_1, MYSQL_COM_STATISTICS, MYSQL_COM_PROCESS_INFO, MYSQL_COM_CONNECT, MYSQL_COM_PROCESS_KILL, MYSQL_COM_DEBUG, MYSQL_COM_PING, MYSQL_COM_TIME, MYSQL_COM_DELAYED_INSERT, MYSQL_COM_CHANGE_USER, MYSQL_COM_BINLOG_DUMP, MYSQL_COM_TABLE_DUMP, MYSQL_COM_CONNECT_OUT, MYSQL_COM_REGISTER_SLAVE, MYSQL_COM_STMT_PREPARE, MYSQL_COM_STMT_EXECUTE, MYSQL_COM_STMT_SEND_LONG_DATA, MYSQL_COM_STMT_CLOSE, MYSQL_COM_STMT_RESET, MYSQL_COM_SET_OPTION, MYSQL_COM_STMT_FETCH, MYSQL_COM_DAEMON, MYSQL_COM_BINLOG_DUMP_GTID, MYSQL_COM_RESET_CONNECTION, MYSQL_COM_CLONE, MYSQL_COM_END }; // MySQL packet type enum { MYSQL_PACKET_OTHER = 0, MYSQL_PACKET_OK, MYSQL_PACKET_NULL, MYSQL_PACKET_EOF, MYSQL_PACKET_ERROR, MYSQL_PACKET_GET_RESULT, MYSQL_PACKET_LOCAL_INLINE, }; // MySQL cursor status enum { MYSQL_STATUS_NOT_INIT = 0, MYSQL_STATUS_OK, MYSQL_STATUS_GET_RESULT, MYSQL_STATUS_ERROR, MYSQL_STATUS_END, }; // Column types for MySQL enum { MYSQL_TYPE_DECIMAL = 0, MYSQL_TYPE_TINY, MYSQL_TYPE_SHORT, MYSQL_TYPE_LONG, MYSQL_TYPE_FLOAT, MYSQL_TYPE_DOUBLE, MYSQL_TYPE_NULL, MYSQL_TYPE_TIMESTAMP, MYSQL_TYPE_LONGLONG, MYSQL_TYPE_INT24, MYSQL_TYPE_DATE, MYSQL_TYPE_TIME, MYSQL_TYPE_DATETIME, MYSQL_TYPE_YEAR, MYSQL_TYPE_NEWDATE, // Internal to MySQL. Not used in protocol MYSQL_TYPE_VARCHAR, MYSQL_TYPE_BIT, MYSQL_TYPE_TIMESTAMP2, MYSQL_TYPE_DATETIME2, // Internal to MySQL. Not used in protocol MYSQL_TYPE_TIME2, // Internal to MySQL. 
Not used in protocol MYSQL_TYPE_TYPED_ARRAY = 244, // Used for replication only MYSQL_TYPE_JSON = 245, MYSQL_TYPE_NEWDECIMAL = 246, MYSQL_TYPE_ENUM = 247, MYSQL_TYPE_SET = 248, MYSQL_TYPE_TINY_BLOB = 249, MYSQL_TYPE_MEDIUM_BLOB = 250, MYSQL_TYPE_LONG_BLOB = 251, MYSQL_TYPE_BLOB = 252, MYSQL_TYPE_VAR_STRING = 253, MYSQL_TYPE_STRING = 254, MYSQL_TYPE_GEOMETRY = 255 }; static inline const char *datatype2str(int data_type) { switch (data_type) { case MYSQL_TYPE_BIT: return "BIT"; case MYSQL_TYPE_BLOB: return "BLOB"; case MYSQL_TYPE_DATE: return "DATE"; case MYSQL_TYPE_DATETIME: return "DATETIME"; case MYSQL_TYPE_NEWDECIMAL: return "NEWDECIMAL"; case MYSQL_TYPE_DECIMAL: return "DECIMAL"; case MYSQL_TYPE_DOUBLE: return "DOUBLE"; case MYSQL_TYPE_ENUM: return "ENUM"; case MYSQL_TYPE_FLOAT: return "FLOAT"; case MYSQL_TYPE_GEOMETRY: return "GEOMETRY"; case MYSQL_TYPE_INT24: return "INT24"; case MYSQL_TYPE_JSON: return "JSON"; case MYSQL_TYPE_LONG: return "LONG"; case MYSQL_TYPE_LONGLONG: return "LONGLONG"; case MYSQL_TYPE_LONG_BLOB: return "LONG_BLOB"; case MYSQL_TYPE_MEDIUM_BLOB: return "MEDIUM_BLOB"; case MYSQL_TYPE_NEWDATE: return "NEWDATE"; case MYSQL_TYPE_NULL: return "NULL"; case MYSQL_TYPE_SET: return "SET"; case MYSQL_TYPE_SHORT: return "SHORT"; case MYSQL_TYPE_STRING: return "STRING"; case MYSQL_TYPE_TIME: return "TIME"; case MYSQL_TYPE_TIMESTAMP: return "TIMESTAMP"; case MYSQL_TYPE_TINY: return "TINY"; case MYSQL_TYPE_TINY_BLOB: return "TINY_BLOB"; case MYSQL_TYPE_VAR_STRING: return "VAR_STRING"; case MYSQL_TYPE_YEAR: return "YEAR"; default: return "?-unknown-?"; } } #endif workflow-0.11.8/src/protocol/redis_parser.c000066400000000000000000000226651476003635400207540ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) Liu Kai (liukaidx@sogou-inc.com) */ #include #include #include #include "list.h" #include "redis_parser.h" #define MIN(x, y) ((x) <= (y) ? (x) : (y)) #define MAX(x, y) ((x) >= (y) ? (x) : (y)) #define REDIS_MSGBUF_INIT_SIZE 8 #define REDIS_REPLY_DEPTH_LIMIT 64 #define REDIS_ARRAY_SIZE_LIMIT (4 * 1024 * 1024) enum { //REDIS_PARSE_INIT = 0, REDIS_GET_CMD = 1, REDIS_GET_CR, REDIS_GET_LF, REDIS_UNTIL_CRLF, REDIS_GET_NCHAR, REDIS_PARSE_END }; struct __redis_read_record { struct list_head list; redis_reply_t *reply; }; void redis_reply_deinit(redis_reply_t *reply) { size_t i; for (i = 0; i < reply->elements; i++) { redis_reply_deinit(reply->element[i]); free(reply->element[i]); } free(reply->element); } static redis_reply_t **__redis_create_array(size_t size, redis_reply_t *reply) { size_t elements = 0; redis_reply_t **element = (redis_reply_t **)malloc(size * sizeof (void *)); if (element) { size_t i; for (i = 0; i < size; i++) { element[i] = (redis_reply_t *)malloc(sizeof (redis_reply_t)); if (element[i]) { redis_reply_init(element[i]); elements++; continue; } break; } if (elements == size) return element; while (elements > 0) free(element[--elements]); free(element); } return NULL; } int redis_reply_set_array(size_t size, redis_reply_t *reply) { redis_reply_t **element = __redis_create_array(size, reply); if (element == NULL) return -1; redis_reply_deinit(reply); reply->element = element; reply->elements = size; reply->type = REDIS_REPLY_TYPE_ARRAY; return 0; } static int __redis_parse_cmd(const char ch, redis_parser_t *parser) { switch (ch) { 
case '+': case '-': case ':': case '$': case '*': parser->cmd = ch; parser->status = REDIS_UNTIL_CRLF; parser->findidx = parser->msgidx; return 0; } return -2; } static int __redis_parse_cr(const char ch, redis_parser_t *parser) { if (ch != '\r') return -2; parser->status = REDIS_GET_LF; return 0; } static int __redis_parse_lf(const char ch, redis_parser_t *parser) { if (ch != '\n') return -2; return 1; } static int __redis_parse_line(redis_parser_t *parser) { char *str = parser->msgbuf + parser->msgidx; size_t slen = parser->findidx - parser->msgidx; char data[32]; int i, n; const char *offset = (const char *)parser->msgidx; struct __redis_read_record *node; parser->msgidx = parser->findidx + 2; switch (parser->cmd) { case '+': redis_reply_set_status(offset, slen, parser->cur); return 1; case '-': redis_reply_set_error(offset, slen, parser->cur); return 1; case ':': if (slen == 0 || slen > 30) return -2; memcpy(data, str, slen); data[slen] = '\0'; redis_reply_set_integer(atoll(data), parser->cur); return 1; case '$': n = atoi(str); if (n < 0) { redis_reply_set_null(parser->cur); return 1; } else if (n == 0) { /* "-0" not acceptable. 
*/ if (!isdigit(*str)) return -2; redis_reply_set_string(offset, 0, parser->cur); parser->status = REDIS_GET_CR; return 0; } parser->nchar = n; parser->status = REDIS_GET_NCHAR; return 0; case '*': n = atoi(str); if (n < 0) { redis_reply_set_null(parser->cur); return 1; } if (n == 0 && !isdigit(*str)) return -2; if (n > REDIS_ARRAY_SIZE_LIMIT) return -2; parser->nleft += n; if (redis_reply_set_array(n, parser->cur) < 0) return -1; if (n == 0) return 1; parser->nleft--; for (i = 0; i < n - 1; i++) { node = (struct __redis_read_record *)malloc(sizeof *node); if (!node) return -1; node->reply = parser->cur->element[n - 1 - i]; list_add(&node->list, &parser->read_list); } parser->cur = parser->cur->element[0]; parser->status = REDIS_GET_CMD; return 0; } return -2; } static int __redis_parse_crlf(redis_parser_t *parser) { char *buf = parser->msgbuf; for (; parser->findidx + 1 < parser->msgsize; parser->findidx++) { if (buf[parser->findidx] == '\r' && buf[parser->findidx + 1] == '\n') return __redis_parse_line(parser); } return 2; } static int __redis_parse_nchar(redis_parser_t *parser) { if (parser->nchar <= parser->msgsize - parser->msgidx) { redis_reply_set_string((const char *)parser->msgidx, parser->nchar, parser->cur); parser->msgidx += parser->nchar; parser->status = REDIS_GET_CR; return 0; } return 2; } //-1 error | 0 continue | 1 finish-one | 2 not-enough static int __redis_parser_forward(redis_parser_t *parser) { char *buf = parser->msgbuf; if (parser->msgidx >= parser->msgsize) return 2; switch (parser->status) { case REDIS_GET_CMD: return __redis_parse_cmd(buf[parser->msgidx++], parser); case REDIS_GET_CR: return __redis_parse_cr(buf[parser->msgidx++], parser); case REDIS_GET_LF: return __redis_parse_lf(buf[parser->msgidx++], parser); case REDIS_UNTIL_CRLF: return __redis_parse_crlf(parser); case REDIS_GET_NCHAR: return __redis_parse_nchar(parser); } return -2; } void redis_parser_init(redis_parser_t *parser) { redis_reply_init(&parser->reply); 
parser->parse_succ = 0; parser->msgbuf = NULL; parser->msgsize = 0; parser->bufsize = 0; //parser->status = REDIS_PARSE_INIT; //parser->nleft = 0; parser->status = REDIS_GET_CMD; parser->nleft = 1; parser->cur = &parser->reply; INIT_LIST_HEAD(&parser->read_list); parser->msgidx = 0; parser->cmd = '\0'; parser->nchar = 0; parser->findidx = 0; } void redis_parser_deinit(redis_parser_t *parser) { struct list_head *pos, *tmp; struct __redis_read_record *next; list_for_each_safe(pos, tmp, &parser->read_list) { next = list_entry(pos, struct __redis_read_record, list); list_del(pos); free(next); } redis_reply_deinit(&parser->reply); free(parser->msgbuf); } static int __redis_parse_done(redis_reply_t *reply, char *buf, int depth) { size_t i; if (depth == REDIS_REPLY_DEPTH_LIMIT) return -2; switch (reply->type) { case REDIS_REPLY_TYPE_INTEGER: break; case REDIS_REPLY_TYPE_ARRAY: for (i = 0; i < reply->elements; i++) { if (__redis_parse_done(reply->element[i], buf, depth + 1) < 0) return -2; } break; case REDIS_REPLY_TYPE_STATUS: case REDIS_REPLY_TYPE_ERROR: case REDIS_REPLY_TYPE_STRING: reply->str = buf + (size_t)reply->str; break; } return 1; } static int __redis_split_inline_command(redis_parser_t *parser) { char *msg = parser->msgbuf; char *end = msg + parser->msgsize; size_t arr_size = 0; redis_reply_t **ele; char *cur; int ret; while (msg != end) { while (msg != end && isspace(*msg)) msg++; if (msg == end) break; arr_size++; while (msg != end && !isspace(*msg)) msg++; } if (arr_size == 0) { parser->msgsize = 0; parser->msgidx = 0; return 0; } ret = redis_reply_set_array(arr_size, &parser->reply); if (ret < 0) return ret; ele = parser->reply.element; msg = parser->msgbuf; while (msg != end) { while (msg != end && isspace(*msg)) msg++; if (msg == end) break; cur = msg; while (cur != end && !isspace(*cur)) cur++; redis_reply_set_string(msg, cur - msg, *ele); msg = cur; ele++; } parser->status = REDIS_PARSE_END; return 1; } int redis_parser_append_message(const void *buf, 
size_t *size, redis_parser_t *parser) { size_t msgsize_bak = parser->msgsize; if (parser->status == REDIS_PARSE_END) { *size = 0; return 1; } if (parser->msgsize + *size > parser->bufsize) { size_t new_size = MAX(REDIS_MSGBUF_INIT_SIZE, 2 * parser->bufsize); void *new_base; while (new_size < parser->msgsize + *size) new_size *= 2; new_base = realloc(parser->msgbuf, new_size); if (!new_base) return -1; parser->msgbuf = (char *)new_base; parser->bufsize = new_size; } memcpy(parser->msgbuf + parser->msgsize, buf, *size); parser->msgsize += *size; if (parser->msgsize > 0 && (isalpha(*parser->msgbuf) || isspace(*parser->msgbuf))) { while (parser->msgidx < parser->msgsize && *(parser->msgbuf + parser->msgidx) != '\n') { parser->msgidx++; } if (parser->msgidx == parser->msgsize) return 0; parser->msgidx++; parser->msgsize = parser->msgidx; *size = parser->msgsize - msgsize_bak; return __redis_split_inline_command(parser); } do { int ret = __redis_parser_forward(parser); if (ret < 0) return ret; if (ret == 1) { struct list_head *lnext = parser->read_list.next; struct __redis_read_record *next; parser->nleft--; if (lnext && lnext != &parser->read_list) { next = list_entry(lnext, struct __redis_read_record, list); parser->cur = next->reply; list_del(lnext); free(next); } if (parser->nleft > 0) parser->status = REDIS_GET_CMD; else { parser->parse_succ = 1; parser->status = REDIS_PARSE_END; } } else if (ret == 2) return 0; } while (parser->status != REDIS_PARSE_END); *size = parser->msgidx - msgsize_bak; return __redis_parse_done(&parser->reply, parser->msgbuf, 0); } workflow-0.11.8/src/protocol/redis_parser.h000066400000000000000000000062071476003635400207530ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) */ #ifndef _REDIS_PARSER_H_ #define _REDIS_PARSER_H_ #include #include "list.h" // redis_parser_t is absolutely same as hiredis-redisReply in memory // If you include hiredis.h, redisReply* can cast to redis_reply_t* safely #define REDIS_REPLY_TYPE_STRING 1 #define REDIS_REPLY_TYPE_ARRAY 2 #define REDIS_REPLY_TYPE_INTEGER 3 #define REDIS_REPLY_TYPE_NIL 4 #define REDIS_REPLY_TYPE_STATUS 5 #define REDIS_REPLY_TYPE_ERROR 6 typedef struct __redis_reply { int type; /* REDIS_REPLY_TYPE_* */ long long integer; /* The integer when type is REDIS_REPLY_TYPE_INTEGER */ size_t len; /* Length of string */ char *str; /* Used for both REDIS_REPLY_TYPE_ERROR and REDIS_REPLY_TYPE_STRING */ size_t elements; /* number of elements, for REDIS_REPLY_TYPE_ARRAY */ struct __redis_reply **element; /* elements vector for REDIS_REPLY_TYPE_ARRAY */ } redis_reply_t; typedef struct __redis_parser { int parse_succ;//check first int status; char *msgbuf; size_t msgsize; size_t bufsize; redis_reply_t *cur; struct list_head read_list; size_t msgidx; size_t findidx; int nleft; int nchar; char cmd; redis_reply_t reply; } redis_parser_t; #ifdef __cplusplus extern "C" { #endif void redis_parser_init(redis_parser_t *parser); void redis_parser_deinit(redis_parser_t *parser); int redis_parser_append_message(const void *buf, size_t *size, redis_parser_t *parser); void redis_reply_deinit(redis_reply_t *reply); int redis_reply_set_array(size_t size, redis_reply_t *reply); #ifdef __cplusplus } #endif static inline void redis_reply_init(redis_reply_t *reply) { reply->type = 
REDIS_REPLY_TYPE_NIL; reply->integer = 0; reply->len = 0; reply->str = NULL; reply->elements = 0; reply->element = NULL; } static inline void redis_reply_set_string(const char *str, size_t len, redis_reply_t *reply) { reply->type = REDIS_REPLY_TYPE_STRING; reply->len = len; reply->str = (char *)str; } static inline void redis_reply_set_integer(long long intv, redis_reply_t *reply) { reply->type = REDIS_REPLY_TYPE_INTEGER; reply->integer = intv; } static inline void redis_reply_set_null(redis_reply_t *reply) { reply->type = REDIS_REPLY_TYPE_NIL; } static inline void redis_reply_set_error(const char *err, size_t len, redis_reply_t *reply) { reply->type = REDIS_REPLY_TYPE_ERROR; reply->len = len; reply->str = (char *)err; } static inline void redis_reply_set_status(const char *str, size_t len, redis_reply_t *reply) { reply->type = REDIS_REPLY_TYPE_STATUS; reply->len = len; reply->str = (char *)str; } #endif workflow-0.11.8/src/protocol/xmake.lua000066400000000000000000000030731476003635400177260ustar00rootroot00000000000000target("basic_protocol") set_kind("object") add_files("PackageWrapper.cc", "SSLWrapper.cc", "dns_parser.c", "DnsMessage.cc", "DnsUtil.cc", "http_parser.c", "HttpMessage.cc", "HttpUtil.cc") target("mysql_protocol") if has_config("mysql") then add_files("mysql_stream.c", "mysql_parser.c", "mysql_byteorder.c", "MySQLMessage.cc", "MySQLResult.cc", "MySQLUtil.cc") set_kind("object") add_deps("basic_protocol") else set_kind("phony") end target("redis_protocol") if has_config("redis") then add_files("redis_parser.c", "RedisMessage.cc") set_kind("object") add_deps("basic_protocol") else set_kind("phony") end target("protocol") set_kind("object") add_deps("basic_protocol", "mysql_protocol", "redis_protocol") target("kafka_message") if has_config("kafka") then add_files("KafkaMessage.cc") set_kind("object") add_cxxflags("-fno-rtti") add_packages("lz4", "zstd", "zlib", "snappy") else set_kind("phony") end target("kafka_protocol") if has_config("kafka") then 
set_kind("object") add_files("kafka_parser.c", "KafkaDataTypes.cc", "KafkaResult.cc") add_deps("kafka_message", "protocol") add_packages("zlib", "snappy", "zstd", "lz4") else set_kind("phony") end workflow-0.11.8/src/server/000077500000000000000000000000001476003635400155605ustar00rootroot00000000000000workflow-0.11.8/src/server/CMakeLists.txt000066400000000000000000000003061476003635400203170ustar00rootroot00000000000000cmake_minimum_required(VERSION 3.6) project(server) set(SRC WFServer.cc ) if (NOT MYSQL STREQUAL "n") set(SRC ${SRC} WFMySQLServer.cc ) endif () add_library(${PROJECT_NAME} OBJECT ${SRC}) workflow-0.11.8/src/server/WFDnsServer.h000066400000000000000000000032551476003635400201060ustar00rootroot00000000000000/* Copyright (c) 2021 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Authors: Liu Kai (liukaidx@sogou-inc.com) */ #ifndef _WFDNSSERVER_H_ #define _WFDNSSERVER_H_ #include "DnsMessage.h" #include "WFServer.h" #include "WFTaskFactory.h" using dns_process_t = std::function; using WFDnsServer = WFServer; static constexpr struct WFServerParams DNS_SERVER_PARAMS_DEFAULT = { .transport_type = TT_UDP, .max_connections = 2000, .peer_response_timeout = 10 * 1000, .receive_timeout = -1, .keep_alive_timeout = 300 * 1000, .request_size_limit = (size_t)-1, .ssl_accept_timeout = 5000, }; template<> inline WFDnsServer::WFServer(dns_process_t proc) : WFServerBase(&DNS_SERVER_PARAMS_DEFAULT), process(std::move(proc)) { } template<> inline CommSession *WFDnsServer::new_session(long long seq, CommConnection *conn) { WFDnsTask *task; task = WFServerTaskFactory::create_dns_task(this, this->process); task->set_keep_alive(this->params.keep_alive_timeout); task->set_receive_timeout(this->params.receive_timeout); task->get_req()->set_size_limit(this->params.request_size_limit); return task; } #endif workflow-0.11.8/src/server/WFHttpServer.h000066400000000000000000000033221476003635400202740ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Authors: Xie Han (xiehan@sogou-inc.com) */ #ifndef _WFHTTPSERVER_H_ #define _WFHTTPSERVER_H_ #include #include "HttpMessage.h" #include "WFServer.h" #include "WFTaskFactory.h" using http_process_t = std::function; using WFHttpServer = WFServer; static constexpr struct WFServerParams HTTP_SERVER_PARAMS_DEFAULT = { .transport_type = TT_TCP, .max_connections = 2000, .peer_response_timeout = 10 * 1000, .receive_timeout = -1, .keep_alive_timeout = 60 * 1000, .request_size_limit = (size_t)-1, .ssl_accept_timeout = 10 * 1000, }; template<> inline WFHttpServer::WFServer(http_process_t proc) : WFServerBase(&HTTP_SERVER_PARAMS_DEFAULT), process(std::move(proc)) { } template<> inline CommSession *WFHttpServer::new_session(long long seq, CommConnection *conn) { WFHttpTask *task; task = WFServerTaskFactory::create_http_task(this, this->process); task->set_keep_alive(this->params.keep_alive_timeout); task->set_receive_timeout(this->params.receive_timeout); task->get_req()->set_size_limit(this->params.request_size_limit); return task; } #endif workflow-0.11.8/src/server/WFMySQLServer.cc000066400000000000000000000031151476003635400204600ustar00rootroot00000000000000/* Copyright (c) 2020 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) */ #include #include "WFMySQLServer.h" WFConnection *WFMySQLServer::new_connection(int accept_fd) { WFConnection *conn = this->WFServer::new_connection(accept_fd); if (conn) { protocol::MySQLHandshakeResponse resp; struct iovec vec[8]; int count; resp.server_set(0x0a, "5.5", 1, (const uint8_t *)"12345678901234567890", 0, 33, 0); count = resp.encode(vec, 8); if (count >= 0) { if (writev(accept_fd, vec, count) >= 0) return conn; } this->delete_connection(conn); } return NULL; } CommSession *WFMySQLServer::new_session(long long seq, CommConnection *conn) { static mysql_process_t empty = [](WFMySQLTask *){ }; WFMySQLTask *task; task = WFServerTaskFactory::create_mysql_task(this, seq ? this->process : empty); task->set_keep_alive(this->params.keep_alive_timeout); task->set_receive_timeout(this->params.receive_timeout); task->get_req()->set_size_limit(this->params.request_size_limit); return task; } workflow-0.11.8/src/server/WFMySQLServer.h000066400000000000000000000030111476003635400203150ustar00rootroot00000000000000/* Copyright (c) 2020 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) */ #ifndef _WFMYSQLSERVER_H_ #define _WFMYSQLSERVER_H_ #include #include "MySQLMessage.h" #include "WFServer.h" #include "WFTaskFactory.h" #include "WFConnection.h" using mysql_process_t = std::function; class MySQLServer; static constexpr struct WFServerParams MYSQL_SERVER_PARAMS_DEFAULT = { .transport_type = TT_TCP, .max_connections = 2000, .peer_response_timeout = 10 * 1000, .receive_timeout = -1, .keep_alive_timeout = 28800 * 1000, .request_size_limit = (size_t)-1, .ssl_accept_timeout = 10 * 1000, }; class WFMySQLServer : public WFServer { public: WFMySQLServer(mysql_process_t proc): WFServer(&MYSQL_SERVER_PARAMS_DEFAULT, std::move(proc)) { } protected: virtual WFConnection *new_connection(int accept_fd); virtual CommSession *new_session(long long seq, CommConnection *conn); }; #endif workflow-0.11.8/src/server/WFRedisServer.h000066400000000000000000000025111476003635400204220ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Authors: Xie Han (xiehan@sogou-inc.com) */ #ifndef _WFREDISSERVER_H_ #define _WFREDISSERVER_H_ #include "RedisMessage.h" #include "WFServer.h" #include "WFTaskFactory.h" using redis_process_t = std::function; using WFRedisServer = WFServer; static constexpr struct WFServerParams REDIS_SERVER_PARAMS_DEFAULT = { .transport_type = TT_TCP, .max_connections = 2000, .peer_response_timeout = 10 * 1000, .receive_timeout = -1, .keep_alive_timeout = 300 * 1000, .request_size_limit = (size_t)-1, .ssl_accept_timeout = 5000, }; template<> inline WFRedisServer::WFServer(redis_process_t proc) : WFServerBase(&REDIS_SERVER_PARAMS_DEFAULT), process(std::move(proc)) { } #endif workflow-0.11.8/src/server/WFServer.cc000066400000000000000000000142731476003635400176010ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Authors: Xie Han (xiehan@sogou-inc.com) Wu Jiaxu (wujiaxu@sogou-inc.com) */ #include #include #include #include #include #include #include #include #include #include "CommScheduler.h" #include "EndpointParams.h" #include "WFConnection.h" #include "WFGlobal.h" #include "WFServer.h" #define PORT_STR_MAX 5 class WFServerConnection : public WFConnection { public: WFServerConnection(std::atomic *conn_count) { this->conn_count = conn_count; } virtual ~WFServerConnection() { (*this->conn_count)--; } private: std::atomic *conn_count; }; int WFServerBase::ssl_ctx_callback(SSL *ssl, int *al, void *arg) { WFServerBase *server = (WFServerBase *)arg; const char *servername = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name); SSL_CTX *ssl_ctx = server->get_server_ssl_ctx(servername); if (!ssl_ctx) return SSL_TLSEXT_ERR_NOACK; if (ssl_ctx != server->get_ssl_ctx()) SSL_set_SSL_CTX(ssl, ssl_ctx); return SSL_TLSEXT_ERR_OK; } SSL_CTX *WFServerBase::new_ssl_ctx(const char *cert_file, const char *key_file) { SSL_CTX *ssl_ctx = WFGlobal::new_ssl_server_ctx(); if (!ssl_ctx) return NULL; if (SSL_CTX_use_certificate_chain_file(ssl_ctx, cert_file) > 0 && SSL_CTX_use_PrivateKey_file(ssl_ctx, key_file, SSL_FILETYPE_PEM) > 0 && SSL_CTX_check_private_key(ssl_ctx) > 0 && SSL_CTX_set_tlsext_servername_callback(ssl_ctx, ssl_ctx_callback) > 0 && SSL_CTX_set_tlsext_servername_arg(ssl_ctx, this) > 0) { return ssl_ctx; } SSL_CTX_free(ssl_ctx); return NULL; } int WFServerBase::init(const struct sockaddr *bind_addr, socklen_t addrlen, const char *cert_file, const char *key_file) { int timeout = this->params.peer_response_timeout; if (this->params.receive_timeout >= 0) { if ((unsigned int)timeout > (unsigned int)this->params.receive_timeout) timeout = this->params.receive_timeout; } if (this->params.transport_type == TT_TCP_SSL || this->params.transport_type == TT_SCTP_SSL) { if (!cert_file || !key_file) { errno = EINVAL; return -1; } } if (this->CommService::init(bind_addr, addrlen, -1, timeout) < 0) 
return -1; if (cert_file && key_file && this->params.transport_type != TT_UDP) { SSL_CTX *ssl_ctx = this->new_ssl_ctx(cert_file, key_file); if (!ssl_ctx) { this->deinit(); return -1; } this->set_ssl(ssl_ctx, this->params.ssl_accept_timeout); } this->scheduler = WFGlobal::get_scheduler(); return 0; } int WFServerBase::create_listen_fd() { if (this->listen_fd < 0) { const struct sockaddr *bind_addr; socklen_t addrlen; int type, protocol; int reuse = 1; switch (this->params.transport_type) { case TT_TCP: case TT_TCP_SSL: type = SOCK_STREAM; protocol = 0; break; case TT_UDP: type = SOCK_DGRAM; protocol = 0; break; #ifdef IPPROTO_SCTP case TT_SCTP: case TT_SCTP_SSL: type = SOCK_STREAM; protocol = IPPROTO_SCTP; break; #endif default: errno = EPROTONOSUPPORT; return -1; } this->get_addr(&bind_addr, &addrlen); this->listen_fd = socket(bind_addr->sa_family, type, protocol); if (this->listen_fd >= 0) { setsockopt(this->listen_fd, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof (int)); } } else this->listen_fd = dup(this->listen_fd); return this->listen_fd; } WFConnection *WFServerBase::new_connection(int accept_fd) { if (++this->conn_count <= this->params.max_connections || this->drain(1) == 1) { int reuse = 1; setsockopt(accept_fd, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof (int)); return new WFServerConnection(&this->conn_count); } this->conn_count--; errno = EMFILE; return NULL; } void WFServerBase::delete_connection(WFConnection *conn) { delete (WFServerConnection *)conn; } void WFServerBase::handle_unbound() { this->mutex.lock(); this->unbind_finish = true; this->cond.notify_one(); this->mutex.unlock(); } int WFServerBase::start(const struct sockaddr *bind_addr, socklen_t addrlen, const char *cert_file, const char *key_file) { SSL_CTX *ssl_ctx; if (this->init(bind_addr, addrlen, cert_file, key_file) >= 0) { if (this->scheduler->bind(this) >= 0) return 0; ssl_ctx = this->get_ssl_ctx(); this->deinit(); if (ssl_ctx) SSL_CTX_free(ssl_ctx); } this->listen_fd = -1; return -1; } int 
WFServerBase::start(int family, const char *host, unsigned short port, const char *cert_file, const char *key_file) { struct addrinfo hints = { .ai_flags = AI_PASSIVE, .ai_family = family, .ai_socktype = SOCK_STREAM, }; struct addrinfo *addrinfo; char port_str[PORT_STR_MAX + 1]; int ret; snprintf(port_str, PORT_STR_MAX + 1, "%d", port); ret = getaddrinfo(host, port_str, &hints, &addrinfo); if (ret == 0) { ret = start(addrinfo->ai_addr, (socklen_t)addrinfo->ai_addrlen, cert_file, key_file); freeaddrinfo(addrinfo); } else { if (ret != EAI_SYSTEM) errno = EINVAL; ret = -1; } return ret; } int WFServerBase::serve(int listen_fd, const char *cert_file, const char *key_file) { struct sockaddr_storage ss; socklen_t len = sizeof ss; if (getsockname(listen_fd, (struct sockaddr *)&ss, &len) < 0) return -1; this->listen_fd = listen_fd; return start((struct sockaddr *)&ss, len, cert_file, key_file); } void WFServerBase::shutdown() { this->listen_fd = -1; this->scheduler->unbind(this); } void WFServerBase::wait_finish() { SSL_CTX *ssl_ctx = this->get_ssl_ctx(); std::unique_lock lock(this->mutex); while (!this->unbind_finish) this->cond.wait(lock); this->deinit(); this->unbind_finish = false; lock.unlock(); if (ssl_ctx) SSL_CTX_free(ssl_ctx); } workflow-0.11.8/src/server/WFServer.h000066400000000000000000000145531476003635400174440ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Authors: Xie Han (xiehan@sogou-inc.com) Wu Jiaxu (wujiaxu@sogou-inc.com) */ #ifndef _WFSERVER_H_ #define _WFSERVER_H_ #include #include #include #include #include #include #include #include #include "EndpointParams.h" #include "WFTaskFactory.h" struct WFServerParams { enum TransportType transport_type; size_t max_connections; int peer_response_timeout; /* timeout of each read or write operation */ int receive_timeout; /* timeout of receiving the whole message */ int keep_alive_timeout; size_t request_size_limit; int ssl_accept_timeout; /* if not ssl, this will be ignored */ }; static constexpr struct WFServerParams SERVER_PARAMS_DEFAULT = { .transport_type = TT_TCP, .max_connections = 2000, .peer_response_timeout = 10 * 1000, .receive_timeout = -1, .keep_alive_timeout = 60 * 1000, .request_size_limit = (size_t)-1, .ssl_accept_timeout = 10 * 1000, }; class WFServerBase : protected CommService { public: WFServerBase(const struct WFServerParams *params) : conn_count(0) { this->params = *params; this->unbind_finish = false; this->listen_fd = -1; } public: /* To start a TCP server */ /* Start on port with IPv4. */ int start(unsigned short port) { return start(AF_INET, NULL, port, NULL, NULL); } /* Start with family. AF_INET or AF_INET6. */ int start(int family, unsigned short port) { return start(family, NULL, port, NULL, NULL); } /* Start with hostname and port. */ int start(const char *host, unsigned short port) { return start(AF_INET, host, port, NULL, NULL); } /* Start with family, hostname and port. */ int start(int family, const char *host, unsigned short port) { return start(family, host, port, NULL, NULL); } /* Start with binding address. */ int start(const struct sockaddr *bind_addr, socklen_t addrlen) { return start(bind_addr, addrlen, NULL, NULL); } /* To start an SSL server. 
*/ int start(unsigned short port, const char *cert_file, const char *key_file) { return start(AF_INET, NULL, port, cert_file, key_file); } int start(int family, unsigned short port, const char *cert_file, const char *key_file) { return start(family, NULL, port, cert_file, key_file); } int start(const char *host, unsigned short port, const char *cert_file, const char *key_file) { return start(AF_INET, host, port, cert_file, key_file); } int start(int family, const char *host, unsigned short port, const char *cert_file, const char *key_file); /* This is the only necessary start function. */ int start(const struct sockaddr *bind_addr, socklen_t addrlen, const char *cert_file, const char *key_file); /* To start with a specified fd. For graceful restart or SCTP server. */ int serve(int listen_fd) { return serve(listen_fd, NULL, NULL); } int serve(int listen_fd, const char *cert_file, const char *key_file); /* stop() is a blocking operation. */ void stop() { this->shutdown(); this->wait_finish(); } /* Nonblocking terminating the server. For stopping multiple servers. * Typically, call shutdown() and then wait_finish(). * But indeed wait_finish() can be called before shutdown(), even before * start() in another thread. */ void shutdown(); void wait_finish(); public: size_t get_conn_count() const { return this->conn_count; } /* Get the listening address. This is often used after starting * server on a random port (start() with port == 0). */ int get_listen_addr(struct sockaddr *addr, socklen_t *addrlen) const { if (this->listen_fd >= 0) return getsockname(this->listen_fd, addr, addrlen); errno = ENOTCONN; return -1; } const struct WFServerParams *get_params() const { return &this->params; } protected: /* Override this function to create the initial SSL CTX of the server */ virtual SSL_CTX *new_ssl_ctx(const char *cert_file, const char *key_file); /* Override this function to implement server that supports TLS SNI. 
* "servername" will be NULL if client does not set a host name. * Returning NULL to indicate that servername is not supported. */ virtual SSL_CTX *get_server_ssl_ctx(const char *servername) { return this->get_ssl_ctx(); } /* This can be used by the implementation of 'new_ssl_ctx'. */ static int ssl_ctx_callback(SSL *ssl, int *al, void *arg); protected: WFServerParams params; protected: virtual int create_listen_fd(); virtual WFConnection *new_connection(int accept_fd); void delete_connection(WFConnection *conn); private: int init(const struct sockaddr *bind_addr, socklen_t addrlen, const char *cert_file, const char *key_file); virtual void handle_unbound(); protected: std::atomic conn_count; private: int listen_fd; bool unbind_finish; std::mutex mutex; std::condition_variable cond; class CommScheduler *scheduler; }; template class WFServer : public WFServerBase { public: WFServer(const struct WFServerParams *params, std::function *)> proc) : WFServerBase(params), process(std::move(proc)) { } WFServer(std::function *)> proc) : WFServerBase(&SERVER_PARAMS_DEFAULT), process(std::move(proc)) { } protected: virtual CommSession *new_session(long long seq, CommConnection *conn); protected: std::function *)> process; }; template CommSession *WFServer::new_session(long long seq, CommConnection *conn) { using factory = WFNetworkTaskFactory; WFNetworkTask *task; task = factory::create_server_task(this, this->process); task->set_keep_alive(this->params.keep_alive_timeout); task->set_receive_timeout(this->params.receive_timeout); task->get_req()->set_size_limit(this->params.request_size_limit); return task; } #endif workflow-0.11.8/src/server/xmake.lua000066400000000000000000000002231476003635400173650ustar00rootroot00000000000000target("server") set_kind("object") add_files("*.cc") if not has_config("mysql") then remove_files("WFMySQLServer.cc") end 
workflow-0.11.8/src/util/000077500000000000000000000000001476003635400152275ustar00rootroot00000000000000workflow-0.11.8/src/util/CMakeLists.txt000066400000000000000000000004101476003635400177620ustar00rootroot00000000000000cmake_minimum_required(VERSION 3.6) project(util) set(SRC json_parser.c EncodeStream.cc StringUtil.cc URIParser.cc ) add_library(${PROJECT_NAME} OBJECT ${SRC}) if (KAFKA STREQUAL "y") set(SRC crc32c.c ) add_library("util_kafka" OBJECT ${SRC}) endif () workflow-0.11.8/src/util/EncodeStream.cc000066400000000000000000000055261476003635400201170ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Authors: Xie Han (xiehan@sogou-inc.com) */ #include #include #include #include "list.h" #include "EncodeStream.h" #define ALIGN(x,a) (((x)+(a)-1)&~((a)-1)) #define ENCODE_BUF_SIZE 1024 struct EncodeBuf { struct list_head list; char *pos; char data[ENCODE_BUF_SIZE]; }; void EncodeStream::clear_buf_data() { struct list_head *pos, *tmp; struct EncodeBuf *entry; list_for_each_safe(pos, tmp, &buf_list_) { entry = list_entry(pos, struct EncodeBuf, list); list_del(pos); delete [](char *)entry; } } void EncodeStream::merge() { size_t len = bytes_ - merged_bytes_; struct EncodeBuf *buf; size_t n; char *p; int i; if (len > ENCODE_BUF_SIZE) n = offsetof(struct EncodeBuf, data) + ALIGN(len, 8); else n = sizeof (struct EncodeBuf); buf = (struct EncodeBuf *)new char[n]; p = buf->data; for (i = merged_size_; i < size_; i++) { memcpy(p, vec_[i].iov_base, vec_[i].iov_len); p += vec_[i].iov_len; } buf->pos = buf->data + ALIGN(len, 8); list_add(&buf->list, &buf_list_); vec_[merged_size_].iov_base = buf->data; vec_[merged_size_].iov_len = len; merged_size_++; merged_bytes_ = bytes_; size_ = merged_size_; } void EncodeStream::append_nocopy(const char *data, size_t len) { if (size_ >= max_) { if (merged_size_ + 1 < max_) merge(); else { size_ = max_ + 1; /* Overflow */ return; } } vec_[size_].iov_base = (char *)data; vec_[size_].iov_len = len; size_++; bytes_ += len; } void EncodeStream::append_copy(const char *data, size_t len) { if (size_ >= max_) { if (merged_size_ + 1 < max_) merge(); else { size_ = max_ + 1; /* Overflow */ return; } } struct EncodeBuf *buf = list_entry(buf_list_.prev, struct EncodeBuf, list); if (list_empty(&buf_list_) || buf->pos + len > buf->data + ENCODE_BUF_SIZE) { size_t n; if (len > ENCODE_BUF_SIZE) n = offsetof(struct EncodeBuf, data) + ALIGN(len, 8); else n = sizeof (struct EncodeBuf); buf = (struct EncodeBuf *)new char[n]; buf->pos = buf->data; list_add_tail(&buf->list, &buf_list_); } memcpy(buf->pos, data, len); vec_[size_].iov_base = buf->pos; 
vec_[size_].iov_len = len; size_++; bytes_ += len; buf->pos += ALIGN(len, 8); if (buf->pos >= buf->data + ENCODE_BUF_SIZE) list_move(&buf->list, &buf_list_); } workflow-0.11.8/src/util/EncodeStream.h000066400000000000000000000054111476003635400177520ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Authors: Xie Han (xiehan@sogou-inc.com) Wu Jiaxu (wujiaxu@sogou-inc.com) */ #ifndef _ENCODESTREAM_H_ #define _ENCODESTREAM_H_ #include #include #include #include #include "list.h" /** * @file EncodeStream.h * @brief Encoder toolbox for protocol message encode */ // make sure max > 0 class EncodeStream { public: EncodeStream() { init_vec(NULL, 0); INIT_LIST_HEAD(&buf_list_); } EncodeStream(struct iovec *vectors, int max) { init_vec(vectors, max); INIT_LIST_HEAD(&buf_list_); } ~EncodeStream() { clear_buf_data(); } void reset(struct iovec *vectors, int max) { clear_buf_data(); init_vec(vectors, max); } int size() const { return size_; } size_t bytes() const { return bytes_; } void append_nocopy(const char *data, size_t len); void append_nocopy(const char *data) { append_nocopy(data, strlen(data)); } void append_nocopy(const std::string& data) { append_nocopy(data.c_str(), data.size()); } void append_copy(const char *data, size_t len); void append_copy(const char *data) { append_copy(data, strlen(data)); } void append_copy(const std::string& data) { append_copy(data.c_str(), data.size()); } private: void init_vec(struct iovec *vectors, int max) { vec_ = 
vectors; max_ = max; bytes_ = 0; size_ = 0; merged_bytes_ = 0; merged_size_ = 0; } void merge(); void clear_buf_data(); private: struct iovec *vec_; int max_; int size_; size_t bytes_; int merged_size_; size_t merged_bytes_; struct list_head buf_list_; }; static inline EncodeStream& operator << (EncodeStream& stream, const char *data) { stream.append_nocopy(data, strlen(data)); return stream; } static inline EncodeStream& operator << (EncodeStream& stream, const std::string& data) { stream.append_nocopy(data.c_str(), data.size()); return stream; } static inline EncodeStream& operator << (EncodeStream& stream, const std::pair& data) { stream.append_nocopy(data.first, data.second); return stream; } static inline EncodeStream& operator << (EncodeStream& stream, int64_t intv) { stream.append_copy(std::to_string(intv)); return stream; } #endif workflow-0.11.8/src/util/LRUCache.h000066400000000000000000000117231476003635400167720ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) Xie Han (xiehan@sogou-inc.com) */ #ifndef _LRUCACHE_H_ #define _LRUCACHE_H_ #include #include "list.h" #include "rbtree.h" /** * @file LRUCache.h * @brief Template LRU Cache */ // RAII: NO. Release ref by LRUCache::release // Thread safety: NO. 
// DONOT change value by handler, use Cache::put instead template class LRUHandle { public: VALUE value; private: LRUHandle(const KEY& k, const VALUE& v) : value(v), key(k) { } KEY key; struct list_head list; struct rb_node rb; bool in_cache; int ref; template friend class LRUCache; }; // RAII: NO. Release ref by LRUCache::release // Define ValueDeleter(VALUE& v) for value deleter // Thread safety: NO // Make sure KEY operator< usable template class LRUCache { protected: typedef LRUHandle Handle; public: LRUCache() { INIT_LIST_HEAD(&this->not_use); INIT_LIST_HEAD(&this->in_use); this->cache_map.rb_node = NULL; this->max_size = 0; this->size = 0; } ~LRUCache() { struct list_head *pos, *tmp; Handle *e; // Error if caller has an unreleased handle assert(list_empty(&this->in_use)); list_for_each_safe(pos, tmp, &this->not_use) { e = list_entry(pos, Handle, list); assert(e->in_cache); e->in_cache = false; assert(e->ref == 1);// Invariant for not_use_ list. this->unref(e); } } // default max_size=0 means no-limit cache // max_size means max cache number of key-value pairs void set_max_size(size_t max_size) { this->max_size = max_size; } // Remove all cache that are not actively in use. 
void prune() { struct list_head *pos, *tmp; Handle *e; list_for_each_safe(pos, tmp, &this->not_use) { e = list_entry(pos, Handle, list); assert(e->ref == 1); rb_erase(&e->rb, &this->cache_map); this->erase_node(e); } } // release handle by get/put void release(const Handle *handle) { this->unref(const_cast(handle)); } // get handler // Need call release when handle no longer needed const Handle *get(const KEY& key) { struct rb_node *p = this->cache_map.rb_node; Handle *bound = NULL; Handle *e; while (p) { e = rb_entry(p, Handle, rb); if (!(e->key < key)) { bound = e; p = p->rb_left; } else p = p->rb_right; } if (bound && !(key < bound->key)) { this->ref(bound); return bound; } return NULL; } // put copy // Need call release when handle no longer needed const Handle *put(const KEY& key, VALUE value) { struct rb_node **p = &this->cache_map.rb_node; struct rb_node *parent = NULL; Handle *bound = NULL; Handle *e; while (*p) { parent = *p; e = rb_entry(*p, Handle, rb); if (!(e->key < key)) { bound = e; p = &(*p)->rb_left; } else p = &(*p)->rb_right; } e = new Handle(key, value); e->in_cache = true; e->ref = 2; list_add_tail(&e->list, &this->in_use); this->size++; if (bound && !(key < bound->key)) { rb_replace_node(&bound->rb, &e->rb, &this->cache_map); this->erase_node(bound); } else { rb_link_node(&e->rb, parent, p); rb_insert_color(&e->rb, &this->cache_map); } if (this->max_size > 0) { while (this->size > this->max_size && !list_empty(&this->not_use)) { Handle *tmp = list_entry(this->not_use.next, Handle, list); assert(tmp->ref == 1); rb_erase(&tmp->rb, &this->cache_map); this->erase_node(tmp); } } return e; } // delete from cache, deleter delay called when all inuse-handle release. 
void del(const KEY& key) { Handle *e = const_cast(this->get(key)); if (e) { this->unref(e); rb_erase(&e->rb, &this->cache_map); this->erase_node(e); } } private: void ref(Handle *e) { if (e->in_cache && e->ref == 1) list_move_tail(&e->list, &this->in_use); e->ref++; } void unref(Handle *e) { assert(e->ref > 0); if (--e->ref == 0) { assert(!e->in_cache); this->value_deleter(e->value); delete e; } else if (e->in_cache && e->ref == 1) list_move_tail(&e->list, &this->not_use); } void erase_node(Handle *e) { assert(e->in_cache); list_del(&e->list); e->in_cache = false; this->size--; this->unref(e); } size_t max_size; size_t size; struct list_head not_use; struct list_head in_use; struct rb_root cache_map; ValueDeleter value_deleter; }; #endif workflow-0.11.8/src/util/StringUtil.cc000066400000000000000000000104601476003635400176430ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) Xie Han (xiehan@sogou-inc.com) */ #include #include #include #include #include "StringUtil.h" static int __hex_to_int(const char s[2]) { int value = 16; if (s[0] <= '9') value *= s[0] - '0'; else value *= toupper(s[0]) - 'A' + 10; if (s[1] <= '9') value += s[1] - '0'; else value += toupper(s[1]) - 'A' + 10; return value; } static inline char __int_to_hex(int n) { return n <= 9 ? 
n + '0' : n - 10 + 'A'; } static size_t __url_decode(char *str) { char *dest = str; char *data = str; while (*data) { if (*data == '%' && isxdigit(data[1]) && isxdigit(data[2])) { *dest = __hex_to_int(data + 1); data += 2; } else if (*data == '+') *dest = ' '; else *dest = *data; data++; dest++; } *dest = '\0'; return dest - str; } void StringUtil::url_decode(std::string& str) { str.resize(__url_decode(const_cast(str.c_str()))); } std::string StringUtil::url_encode(const std::string& str) { const char *cur = str.c_str(); const char *end = cur + str.size(); std::string res; while (cur < end) { if (isalnum(*cur) || *cur == '-' || *cur == '_' || *cur == '.' || *cur == '!' || *cur == '~' || *cur == '*' || *cur == '\'' || *cur == '(' || *cur == ')' || *cur == ':' || *cur == '/' || *cur == '@' || *cur == '?' || *cur == '#' || *cur == '&') { res += *cur; } else if (*cur == ' ') { res += '+'; } else { res += '%'; res += __int_to_hex(((const unsigned char)(*cur)) >> 4); res += __int_to_hex(((const unsigned char)(*cur)) % 16); } cur++; } return res; } std::string StringUtil::url_encode_component(const std::string& str) { const char *cur = str.c_str(); const char *end = cur + str.size(); std::string res; while (cur < end) { if (isalnum(*cur) || *cur == '-' || *cur == '_' || *cur == '.' || *cur == '!' 
|| *cur == '~' || *cur == '*' || *cur == '\'' || *cur == '(' || *cur == ')') { res += *cur; } else if (*cur == ' ') { res += '+'; } else { res += '%'; res += __int_to_hex(((const unsigned char)(*cur)) >> 4); res += __int_to_hex(((const unsigned char)(*cur)) % 16); } cur++; } return res; } std::vector StringUtil::split(const std::string& str, char sep) { std::string::const_iterator cur = str.begin(); std::string::const_iterator end = str.end(); std::string::const_iterator next = find(cur, end, sep); std::vector res; while (next != end) { res.emplace_back(cur, next); cur = next + 1; next = std::find(cur, end, sep); } res.emplace_back(cur, next); return res; } std::vector StringUtil::split_filter_empty(const std::string& str, char sep) { std::vector res; std::string::const_iterator cur = str.begin(); std::string::const_iterator end = str.end(); std::string::const_iterator next = find(cur, end, sep); while (next != end) { if (cur < next) res.emplace_back(cur, next); cur = next + 1; next = find(cur, end, sep); } if (cur < next) res.emplace_back(cur, next); return res; } std::string StringUtil::strip(const std::string& str) { std::string res; if (!str.empty()) { const char *cur = str.c_str(); const char *end = cur + str.size(); while (cur < end) { if (!isspace(*cur)) break; cur++; } while (end > cur) { if (!isspace(*(end - 1))) break; end--; } if (end > cur) res.assign(cur, end - cur); } return res; } bool StringUtil::start_with(const std::string& str, const std::string& prefix) { size_t prefix_len = prefix.size(); if (str.size() < prefix_len) return false; for (size_t i = 0; i < prefix_len; i++) { if (str[i] != prefix[i]) return false; } return true; } workflow-0.11.8/src/util/StringUtil.h000066400000000000000000000025171476003635400175110ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) */ #ifndef _STRINGUTIL_H_ #define _STRINGUTIL_H_ #include #include /** * @file StringUtil.h * @brief String toolbox */ // static class class StringUtil { public: static void url_decode(std::string& str); static std::string url_encode(const std::string& str); static std::string url_encode_component(const std::string& str); static std::vector split(const std::string& str, char sep); static std::string strip(const std::string& str); static bool start_with(const std::string& str, const std::string& prefix); //this will filter any empty result, so the result vector has no empty string static std::vector split_filter_empty(const std::string& str, char sep); }; #endif workflow-0.11.8/src/util/URIParser.cc000066400000000000000000000264211476003635400173570ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) Wang Zhulei (wangzhulei@sogou-inc.com) Xie Han (xiehan@sogou-inc.com) */ #include #include #include #include #include #include "StringUtil.h" #include "URIParser.h" enum { URI_SCHEME, URI_USERINFO, URI_HOST, URI_PORT, URI_QUERY, URI_FRAGMENT, URI_PATH, URI_PART_ELEMENTS, }; //scheme://[userinfo@]host[:port][/path][?query][#fragment] //0-6 (scheme, userinfo, host, port, path, query, fragment) static constexpr unsigned char valid_char[4][256] = { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, }; static unsigned char authority_map[256] = { URI_PART_ELEMENTS, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, URI_FRAGMENT, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, URI_PATH, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, URI_HOST, 0, 0, 0, 0, URI_QUERY, URI_USERINFO, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, }; ParsedURI::ParsedURI(ParsedURI&& uri) { scheme = uri.scheme; userinfo = uri.userinfo; host = uri.host; port = uri.port; path = uri.path; query = uri.query; fragment = uri.fragment; state = uri.state; error = uri.error; uri.init(); } ParsedURI& ParsedURI::operator= (ParsedURI&& uri) { if (this != &uri) { deinit(); scheme = uri.scheme; userinfo = uri.userinfo; host = uri.host; port = uri.port; path = uri.path; query = uri.query; fragment = uri.fragment; state = uri.state; error = uri.error; uri.init(); } return *this; } void ParsedURI::copy(const ParsedURI& uri) { init(); state = uri.state; error = uri.error; if (state == URI_STATE_SUCCESS) { bool succ = false; do { if (uri.scheme) { scheme = strdup(uri.scheme); if (!scheme) break; } if (uri.userinfo) { userinfo = strdup(uri.userinfo); if (!userinfo) break; } if (uri.host) { host = strdup(uri.host); if (!host) break; } if (uri.port) { port = strdup(uri.port); if (!port) break; } if (uri.path) { path = strdup(uri.path); if (!path) break; } if (uri.query) { query = strdup(uri.query); if (!query) break; } if (uri.fragment) { fragment = strdup(uri.fragment); if (!fragment) break; } succ = true; } while (0); if (!succ) { deinit(); init(); state = URI_STATE_ERROR; error = errno; } } } int URIParser::parse(const char *str, ParsedURI& uri) { uri.state = URI_STATE_INVALID; int start_idx[URI_PART_ELEMENTS] = {0}; int end_idx[URI_PART_ELEMENTS] = {0}; int pre_state = URI_SCHEME; bool in_ipv6 = false; int i; for (i = 0; str[i]; i++) { if (str[i] == ':') { end_idx[URI_SCHEME] = 
i++; break; } } if (end_idx[URI_SCHEME] == 0) return -1; if (str[i] == '/' && str[i + 1] == '/') { pre_state = URI_HOST; i += 2; if (str[i] == '[') in_ipv6 = true; else start_idx[URI_USERINFO] = i; start_idx[URI_HOST] = i; } else { pre_state = URI_PATH; start_idx[URI_PATH] = i; } bool skip_path = false; if (start_idx[URI_PATH] == 0) { for (; ; i++) { switch (authority_map[(unsigned char)str[i]]) { case 0: continue; case URI_USERINFO: if (str[i + 1] == '[') in_ipv6 = true; end_idx[URI_USERINFO] = i; start_idx[URI_HOST] = i + 1; pre_state = URI_HOST; continue; case URI_HOST: if (str[i - 1] == ']') in_ipv6 = false; if (!in_ipv6) { end_idx[URI_HOST] = i; start_idx[URI_PORT] = i + 1; pre_state = URI_PORT; } continue; case URI_QUERY: end_idx[pre_state] = i; start_idx[URI_QUERY] = i + 1; pre_state = URI_QUERY; skip_path = true; continue; case URI_FRAGMENT: end_idx[pre_state] = i; start_idx[URI_FRAGMENT] = i + 1; end_idx[URI_FRAGMENT] = i + strlen(str + i); pre_state = URI_PART_ELEMENTS; skip_path = true; break; case URI_PATH: if (skip_path) continue; start_idx[URI_PATH] = i; break; case URI_PART_ELEMENTS: skip_path = true; break; } break; } } if (pre_state != URI_PART_ELEMENTS) end_idx[pre_state] = i; if (!skip_path) { pre_state = URI_PATH; for (; str[i]; i++) { if (str[i] == '?') { end_idx[URI_PATH] = i; start_idx[URI_QUERY] = i + 1; pre_state = URI_QUERY; while (str[i + 1]) { if (str[++i] == '#') break; } } if (str[i] == '#') { end_idx[pre_state] = i; start_idx[URI_FRAGMENT] = i + 1; pre_state = URI_FRAGMENT; break; } } end_idx[pre_state] = i + strlen(str + i); } for (int i = 0; i < URI_QUERY; i++) { for (int j = start_idx[i]; j < end_idx[i]; j++) { if (!valid_char[i][(unsigned char)str[j]]) return -1;//invalid char } } char **dst[URI_PART_ELEMENTS] = {&uri.scheme, &uri.userinfo, &uri.host, &uri.port, &uri.query, &uri.fragment, &uri.path}; for (int i = 0; i < URI_PART_ELEMENTS; i++) { if (end_idx[i] > start_idx[i]) { size_t len = end_idx[i] - start_idx[i]; *dst[i] = 
(char *)realloc(*dst[i], len + 1); if (*dst[i] == NULL) { uri.state = URI_STATE_ERROR; uri.error = errno; return -1; } if (i == URI_HOST && str[start_idx[i]] == '[' && str[end_idx[i] - 1] == ']') { len -= 2; memcpy(*dst[i], str + start_idx[i] + 1, len); } else memcpy(*dst[i], str + start_idx[i], len); (*dst[i])[len] = '\0'; } else { free(*dst[i]); *dst[i] = NULL; } } uri.state = URI_STATE_SUCCESS; return 0; } std::map> URIParser::split_query_strict(const std::string &query) { std::map> res; if (query.empty()) return res; std::vector arr = StringUtil::split(query, '&'); if (arr.empty()) return res; for (const auto& ele : arr) { if (ele.empty()) continue; std::vector kv = StringUtil::split(ele, '='); size_t kv_size = kv.size(); std::string& key = kv[0]; if (key.empty()) continue; if (kv_size == 1) { res[key].emplace_back(); continue; } std::string& val = kv[1]; if (val.empty()) res[key].emplace_back(); else res[key].emplace_back(std::move(val)); } return res; } std::map URIParser::split_query(const std::string &query) { std::map res; if (query.empty()) return res; std::vector arr = StringUtil::split(query, '&'); if (arr.empty()) return res; for (const auto& ele : arr) { if (ele.empty()) continue; std::vector kv = StringUtil::split(ele, '='); size_t kv_size = kv.size(); std::string& key = kv[0]; if (key.empty() || res.count(key) > 0) continue; if (kv_size == 1) { res.emplace(std::move(key), ""); continue; } std::string& val = kv[1]; if (val.empty()) res.emplace(std::move(key), ""); else res.emplace(std::move(key), std::move(val)); } return res; } std::vector URIParser::split_path(const std::string &path) { return StringUtil::split_filter_empty(path, '/'); } workflow-0.11.8/src/util/URIParser.h000066400000000000000000000044651476003635400172250ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) Wang Zhulei (wangzhulei@sogou-inc.com) */ #ifndef _URIPARSER_H_ #define _URIPARSER_H_ #include #include #include #include #define URI_STATE_INIT 0 #define URI_STATE_SUCCESS 1 #define URI_STATE_INVALID 2 #define URI_STATE_ERROR 3 /** * @file URIParser.h * @brief URI parser */ // RAII: YES class ParsedURI { public: char *scheme; char *userinfo; char *host; char *port; char *path; char *query; char *fragment; int state; int error; ParsedURI() { init(); } virtual ~ParsedURI() { deinit(); } //copy constructor ParsedURI(const ParsedURI& uri) { copy(uri); } //copy operator ParsedURI& operator= (const ParsedURI& uri) { if (this != &uri) { deinit(); copy(uri); } return *this; } //move constructor ParsedURI(ParsedURI&& uri); //move operator ParsedURI& operator= (ParsedURI&& uri); private: void init() { scheme = NULL; userinfo = NULL; host = NULL; port = NULL; path = NULL; query = NULL; fragment = NULL; state = URI_STATE_INIT; error = 0; } void deinit() { free(scheme); free(userinfo); free(host); free(port); free(path); free(query); free(fragment); } void copy(const ParsedURI& uri); }; // static class class URIParser { public: // return 0 mean succ, -1 mean fail static int parse(const char *str, ParsedURI& uri); static int parse(const std::string& str, ParsedURI& uri) { return parse(str.c_str(), uri); } static std::map> split_query_strict(const std::string &query); static std::map split_query(const std::string &query); static std::vector split_path(const std::string &path); }; #endif 
workflow-0.11.8/src/util/crc32c.c000066400000000000000000000360221476003635400164550ustar00rootroot00000000000000/* Copied from http://stackoverflow.com/a/17646775/1821055 * with the following modifications: * * remove test code * * global hw/sw initialization to be called once per process * * HW support is determined by configure's WITH_CRC32C_HW * * Windows porting (no hardware support on Windows yet) * * FIXME: * * Hardware support on Windows (MSVC assembler) * * Hardware support on ARM */ /* crc32c.c -- compute CRC-32C using the Intel crc32 instruction * Copyright (C) 2013 Mark Adler * Version 1.1 1 Aug 2013 Mark Adler */ /* This software is provided 'as-is', without any express or implied warranty. In no event will the author be held liable for any damages arising from the use of this software. Permission is granted to anyone to use this software for any purpose, including commercial applications, and to alter it and redistribute it freely, subject to the following restrictions: 1. The origin of this software must not be misrepresented; you must not claim that you wrote the original software. If you use this software in a product, an acknowledgment in the product documentation would be appreciated but is not required. 2. Altered source versions must be plainly marked as such, and must not be misrepresented as being the original software. 3. This notice may not be removed or altered from any source distribution. Mark Adler madler@alumni.caltech.edu */ /* Use hardware CRC instruction on Intel SSE 4.2 processors. This computes a CRC-32C, *not* the CRC-32 used by Ethernet and zip, gzip, etc. A software version is provided as a fall-back, as well as for speed comparisons. */ /* Version history: 1.0 10 Feb 2013 First version 1.1 1 Aug 2013 Correct comments on why three crc instructions in parallel */ #include #include #include #include /** * Provides portable endian-swapping macros/functions. 
* * be64toh() * htobe64() * be32toh() * htobe32() * be16toh() * htobe16() * le64toh() */ #ifdef __FreeBSD__ #include #elif defined __GLIBC__ #include #ifndef be64toh /* Support older glibc (<2.9) which lack be64toh */ #include #if __BYTE_ORDER == __BIG_ENDIAN #define be16toh(x) (x) #define be32toh(x) (x) #define be64toh(x) (x) #define le64toh(x) __bswap_64 (x) #define le32toh(x) __bswap_32 (x) #else #define be16toh(x) __bswap_16 (x) #define be32toh(x) __bswap_32 (x) #define be64toh(x) __bswap_64 (x) #define le64toh(x) (x) #define le32toh(x) (x) #endif #endif #elif defined __CYGWIN__ #include #elif defined __BSD__ #include #elif defined __sun #include #include #define __LITTLE_ENDIAN 1234 #define __BIG_ENDIAN 4321 #ifdef _BIG_ENDIAN #define __BYTE_ORDER __BIG_ENDIAN #define be64toh(x) (x) #define be32toh(x) (x) #define be16toh(x) (x) #define le16toh(x) ((uint16_t)BSWAP_16(x)) #define le32toh(x) BSWAP_32(x) #define le64toh(x) BSWAP_64(x) # else #define __BYTE_ORDER __LITTLE_ENDIAN #define be64toh(x) BSWAP_64(x) #define be32toh(x) ntohl(x) #define be16toh(x) ntohs(x) #define le16toh(x) (x) #define le32toh(x) (x) #define le64toh(x) (x) #define htole16(x) (x) #define htole64(x) (x) #endif /* __sun */ #elif defined __APPLE__ #include #include #if __DARWIN_BYTE_ORDER == __DARWIN_BIG_ENDIAN #define be64toh(x) (x) #define be32toh(x) (x) #define be16toh(x) (x) #define le16toh(x) OSSwapInt16(x) #define le32toh(x) OSSwapInt32(x) #define le64toh(x) OSSwapInt64(x) #else #define be64toh(x) OSSwapInt64(x) #define be32toh(x) OSSwapInt32(x) #define be16toh(x) OSSwapInt16(x) #define le16toh(x) (x) #define le32toh(x) (x) #define le64toh(x) (x) #endif #elif defined(_WIN32) #include #define be64toh(x) _byteswap_uint64(x) #define be32toh(x) _byteswap_ulong(x) #define be16toh(x) _byteswap_ushort(x) #define le16toh(x) (x) #define le32toh(x) (x) #define le64toh(x) (x) #elif defined _AIX /* AIX is always big endian */ #define be64toh(x) (x) #define be32toh(x) (x) #define be16toh(x) (x) 
#define le32toh(x) \ ((((x) & 0xff) << 24) | \ (((x) & 0xff00) << 8) | \ (((x) & 0xff0000) >> 8) | \ (((x) & 0xff000000) >> 24)) #define le64toh(x) \ ((((x) & 0x00000000000000ffL) << 56) | \ (((x) & 0x000000000000ff00L) << 40) | \ (((x) & 0x0000000000ff0000L) << 24) | \ (((x) & 0x00000000ff000000L) << 8) | \ (((x) & 0x000000ff00000000L) >> 8) | \ (((x) & 0x0000ff0000000000L) >> 24) | \ (((x) & 0x00ff000000000000L) >> 40) | \ (((x) & 0xff00000000000000L) >> 56)) #else #include #endif /* * On Solaris, be64toh is a function, not a macro, so there's no need to error * if it's not defined. */ #if !defined(__sun) && !defined(be64toh) #error Missing definition for be64toh #endif #ifndef be32toh #define be32toh(x) ntohl(x) #endif #ifndef be16toh #define be16toh(x) ntohs(x) #endif #ifndef htobe64 #define htobe64(x) be64toh(x) #endif #ifndef htobe32 #define htobe32(x) be32toh(x) #endif #ifndef htobe16 #define htobe16(x) be16toh(x) #endif #ifndef htole32 #define htole32(x) le32toh(x) #endif /* CRC-32C (iSCSI) polynomial in reversed bit order. */ #define POLY 0x82f63b78 /* Table for a quadword-at-a-time software crc. */ static uint32_t crc32c_table[8][256]; /* Construct table for software CRC-32C calculation. */ static void crc32c_init_sw(void) { uint32_t n, crc, k; for (n = 0; n < 256; n++) { crc = n; crc = crc & 1 ? (crc >> 1) ^ POLY : crc >> 1; crc = crc & 1 ? (crc >> 1) ^ POLY : crc >> 1; crc = crc & 1 ? (crc >> 1) ^ POLY : crc >> 1; crc = crc & 1 ? (crc >> 1) ^ POLY : crc >> 1; crc = crc & 1 ? (crc >> 1) ^ POLY : crc >> 1; crc = crc & 1 ? (crc >> 1) ^ POLY : crc >> 1; crc = crc & 1 ? (crc >> 1) ^ POLY : crc >> 1; crc = crc & 1 ? (crc >> 1) ^ POLY : crc >> 1; crc32c_table[0][n] = crc; } for (n = 0; n < 256; n++) { crc = crc32c_table[0][n]; for (k = 1; k < 8; k++) { crc = crc32c_table[0][crc & 0xff] ^ (crc >> 8); crc32c_table[k][n] = crc; } } } /* Table-driven software version as a fall-back. This is about 15 times slower than using the hardware instructions. 
This assumes little-endian integers, as is the case on Intel processors that the assembler code here is for. */ static uint32_t crc32c_sw(uint32_t crci, const void *buf, size_t len) { const unsigned char *next = buf; uint64_t crc; crc = crci ^ 0xffffffff; while (len && ((uintptr_t)next & 7) != 0) { crc = crc32c_table[0][(crc ^ *next++) & 0xff] ^ (crc >> 8); len--; } while (len >= 8) { /* Alignment-safe */ uint64_t ncopy; memcpy(&ncopy, next, sizeof(ncopy)); crc ^= le64toh(ncopy); crc = crc32c_table[7][crc & 0xff] ^ crc32c_table[6][(crc >> 8) & 0xff] ^ crc32c_table[5][(crc >> 16) & 0xff] ^ crc32c_table[4][(crc >> 24) & 0xff] ^ crc32c_table[3][(crc >> 32) & 0xff] ^ crc32c_table[2][(crc >> 40) & 0xff] ^ crc32c_table[1][(crc >> 48) & 0xff] ^ crc32c_table[0][crc >> 56]; next += 8; len -= 8; } while (len) { crc = crc32c_table[0][(crc ^ *next++) & 0xff] ^ (crc >> 8); len--; } return (uint32_t)crc ^ 0xffffffff; } #if WITH_CRC32C_HW static int sse42; /* Cached SSE42 support */ /* Multiply a matrix times a vector over the Galois field of two elements, GF(2). Each element is a bit in an unsigned integer. mat must have at least as many entries as the power of two for most significant one bit in vec. */ static RD_INLINE uint32_t gf2_matrix_times(uint32_t *mat, uint32_t vec) { uint32_t sum; sum = 0; while (vec) { if (vec & 1) sum ^= *mat; vec >>= 1; mat++; } return sum; } /* Multiply a matrix by itself over GF(2). Both mat and square must have 32 rows. */ static RD_INLINE void gf2_matrix_square(uint32_t *square, uint32_t *mat) { int n; for (n = 0; n < 32; n++) square[n] = gf2_matrix_times(mat, mat[n]); } /* Construct an operator to apply len zeros to a crc. len must be a power of two. If len is not a power of two, then the result is the same as for the largest power of two less than len. The result for len == 0 is the same as for len == 1. A version of this routine could be easily written for any len, but that is not needed for this application. 
*/
/* On return, even[0..31] holds the operator; odd[] is scratch space. */
static void crc32c_zeros_op(uint32_t *even, size_t len)
{
	int n;
	uint32_t row;
	uint32_t odd[32]; /* odd-power-of-two zeros operator */

	/* put operator for one zero bit in odd */
	odd[0] = POLY; /* CRC-32C polynomial */
	row = 1;
	for (n = 1; n < 32; n++)
	{
		odd[n] = row;
		row <<= 1;
	}

	/* put operator for two zero bits in even */
	gf2_matrix_square(even, odd);

	/* put operator for four zero bits in odd */
	gf2_matrix_square(odd, even);

	/* first square will put the operator for one zero byte (eight zero bits),
	   in even -- next square puts operator for two zero bytes in odd, and so
	   on, until len has been rotated down to zero */
	do
	{
		gf2_matrix_square(even, odd);
		len >>= 1;
		/* the answer is already in even -- nothing to copy */
		if (len == 0)
			return;

		gf2_matrix_square(odd, even);
		len >>= 1;
	} while (len);

	/* answer ended up in odd -- copy to even */
	for (n = 0; n < 32; n++)
		even[n] = odd[n];
}

/* Take a length and build four lookup tables for applying the zeros operator
   for that length, byte-by-byte on the operand. */
/* zeros[k][n] is the operator applied to byte value n placed at byte
   position k of a 32-bit crc (n << 8*k). */
static void crc32c_zeros(uint32_t zeros[][256], size_t len)
{
	uint32_t n;
	uint32_t op[32];

	crc32c_zeros_op(op, len);
	for (n = 0; n < 256; n++)
	{
		zeros[0][n] = gf2_matrix_times(op, n);
		zeros[1][n] = gf2_matrix_times(op, n << 8);
		zeros[2][n] = gf2_matrix_times(op, n << 16);
		zeros[3][n] = gf2_matrix_times(op, n << 24);
	}
}

/* Apply the zeros operator table to crc. */
/* One table lookup per byte of crc; the four results are XORed together. */
static RD_INLINE uint32_t crc32c_shift(uint32_t zeros[][256], uint32_t crc)
{
	return zeros[0][crc & 0xff] ^ zeros[1][(crc >> 8) & 0xff] ^
		   zeros[2][(crc >> 16) & 0xff] ^ zeros[3][crc >> 24];
}

/* Block sizes for three-way parallel crc computation. LONG and SHORT must
   both be powers of two. The associated string constants must be set
   accordingly, for use in constructing the assembler instructions. */
/* The x1/x2 strings are one and two times the block size, written as text
   so they can be pasted into the asm templates as displacements. */
#define LONG 8192
#define LONGx1 "8192"
#define LONGx2 "16384"
#define SHORT 256
#define SHORTx1 "256"
#define SHORTx2 "512"

/* Tables for hardware crc that shift a crc by LONG and SHORT zeros.
*/ static uint32_t crc32c_long[4][256]; static uint32_t crc32c_short[4][256]; /* Initialize tables for shifting crcs. */ static void crc32c_init_hw(void) { crc32c_zeros(crc32c_long, LONG); crc32c_zeros(crc32c_short, SHORT); } /* Compute CRC-32C using the Intel hardware instruction. */ static uint32_t crc32c_hw(uint32_t crc, const void *buf, size_t len) { const unsigned char *next = buf; const unsigned char *end; uint64_t crc0, crc1, crc2; /* need to be 64 bits for crc32q */ /* pre-process the crc */ crc0 = crc ^ 0xffffffff; /* compute the crc for up to seven leading bytes to bring the data pointer to an eight-byte boundary */ while (len && ((uintptr_t)next & 7) != 0) { __asm__("crc32b\t" "(%1), %0" : "=r"(crc0) : "r"(next), "0"(crc0)); next++; len--; } /* compute the crc on sets of LONG*3 bytes, executing three independent crc instructions, each on LONG bytes -- this is optimized for the Nehalem, Westmere, Sandy Bridge, and Ivy Bridge architectures, which have a throughput of one crc per cycle, but a latency of three cycles */ while (len >= LONG*3) { crc1 = 0; crc2 = 0; end = next + LONG; do { __asm__("crc32q\t" "(%3), %0\n\t" "crc32q\t" LONGx1 "(%3), %1\n\t" "crc32q\t" LONGx2 "(%3), %2" : "=r"(crc0), "=r"(crc1), "=r"(crc2) : "r"(next), "0"(crc0), "1"(crc1), "2"(crc2)); next += 8; } while (next < end); crc0 = crc32c_shift(crc32c_long, crc0) ^ crc1; crc0 = crc32c_shift(crc32c_long, crc0) ^ crc2; next += LONG*2; len -= LONG*3; } /* do the same thing, but now on SHORT*3 blocks for the remaining data less than a LONG*3 block */ while (len >= SHORT*3) { crc1 = 0; crc2 = 0; end = next + SHORT; do { __asm__("crc32q\t" "(%3), %0\n\t" "crc32q\t" SHORTx1 "(%3), %1\n\t" "crc32q\t" SHORTx2 "(%3), %2" : "=r"(crc0), "=r"(crc1), "=r"(crc2) : "r"(next), "0"(crc0), "1"(crc1), "2"(crc2)); next += 8; } while (next < end); crc0 = crc32c_shift(crc32c_short, crc0) ^ crc1; crc0 = crc32c_shift(crc32c_short, crc0) ^ crc2; next += SHORT*2; len -= SHORT*3; } /* compute the crc on the 
remaining eight-byte units less than a SHORT*3 block */ end = next + (len - (len & 7)); while (next < end) { __asm__("crc32q\t" "(%1), %0" : "=r"(crc0) : "r"(next), "0"(crc0)); next += 8; } len &= 7; /* compute the crc for up to seven trailing bytes */ while (len) { __asm__("crc32b\t" "(%1), %0" : "=r"(crc0) : "r"(next), "0"(crc0)); next++; len--; } /* return a post-processed crc */ return (uint32_t)crc0 ^ 0xffffffff; } /* Check for SSE 4.2. SSE 4.2 was first supported in Nehalem processors introduced in November, 2008. This does not check for the existence of the cpuid instruction itself, which was introduced on the 486SL in 1992, so this will fail on earlier x86 processors. cpuid works on all Pentium and later processors. */ #define SSE42(have) \ do { \ uint32_t eax, ecx; \ eax = 1; \ __asm__("cpuid" \ : "=c"(ecx) \ : "a"(eax) \ : "%ebx", "%edx"); \ (have) = (ecx >> 20) & 1; \ } while (0) #endif /* WITH_CRC32C_HW */ /* Compute a CRC-32C. If the crc32 instruction is available, use the hardware version. Otherwise, use the software version. */ uint32_t crc32c(uint32_t crc, const void *buf, size_t len) { #if WITH_CRC32C_HW if (sse42) return crc32c_hw(crc, buf, len); else #endif return crc32c_sw(crc, buf, len); } /** * @brief Populate shift tables once */ void crc32c_global_init (void) { #if WITH_CRC32C_HW SSE42(sse42); if (sse42) crc32c_init_hw(); else #endif crc32c_init_sw(); } workflow-0.11.8/src/util/crc32c.h000066400000000000000000000031511476003635400164570ustar00rootroot00000000000000/* * Copyright (c) 2017 Magnus Edenhill * All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. * 2. 
Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. */ #ifndef _CRC32C_H_ #define _CRC32C_H_ #include #include #ifdef __cplusplus extern "C" { #endif uint32_t crc32c(uint32_t crc, const void *buf, size_t len); void crc32c_global_init (void); #ifdef __cplusplus } #endif #endif /* _CRC32C_H_ */ workflow-0.11.8/src/util/json_parser.c000066400000000000000000000630541476003635400177300ustar00rootroot00000000000000/* Copyright (c) 2022 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Author: Xie Han (xiehan@sogou-inc.com) */ #include #include #include #include #include #include #include "list.h" #include "rbtree.h" #include "json_parser.h" #define JSON_DEPTH_LIMIT 1024 struct __json_object { struct list_head head; struct rb_root root; int size; }; struct __json_array { struct list_head head; int size; }; struct __json_value { union { char *string; double number; json_object_t object; json_array_t array; } value; int type; }; struct __json_member { struct list_head list; struct rb_node rb; json_value_t value; char name[1]; }; struct __json_element { struct list_head list; json_value_t value; }; typedef struct __json_member json_member_t; typedef struct __json_element json_element_t; static int __json_string_length(const char *cursor) { int len = 0; while (*cursor != '\"') { if (*cursor == '\\') { cursor++; if (*cursor == '\0') return -2; } else if ((unsigned char)*cursor < 0x20) return -2; cursor++; len++; } return len; } static int __parse_json_hex4(const char *cursor, const char **end, unsigned int *code) { int hex; int i; *code = 0; for (i = 0; i < 4; i++) { hex = *cursor; if (hex >= '0' && hex <= '9') hex = hex - '0'; else if (hex >= 'A' && hex <= 'F') hex = hex - 'A' + 10; else if (hex >= 'a' && hex <= 'f') hex = hex - 'a' + 10; else return -2; *code = (*code << 4) + hex; cursor++; } *end = cursor; return 0; } static int __parse_json_unicode(const char *cursor, const char **end, char *utf8) { unsigned int code; unsigned int next; int ret; ret = __parse_json_hex4(cursor, end, &code); if (ret < 0) return ret; if (code >= 0xdc00 && code <= 0xdfff) return -2; if (code >= 0xd800 && code <= 0xdbff) { cursor = *end; if (*cursor != '\\') return -2; cursor++; if (*cursor != 'u') return -2; cursor++; ret = __parse_json_hex4(cursor, end, &next); if (ret < 0) return ret; if (next < 0xdc00 || next > 0xdfff) return -2; code = (((code & 0x3ff) << 10) | (next & 0x3ff)) + 0x10000; } if (code <= 0x7f) { utf8[0] = code; return 1; } else if (code <= 0x7ff) { 
utf8[0] = 0xc0 | (code >> 6); utf8[1] = 0x80 | (code & 0x3f); return 2; } else if (code <= 0xffff) { utf8[0] = 0xe0 | (code >> 12); utf8[1] = 0x80 | ((code >> 6) & 0x3f); utf8[2] = 0x80 | (code & 0x3f); return 3; } else { utf8[0] = 0xf0 | (code >> 18); utf8[1] = 0x80 | ((code >> 12) & 0x3f); utf8[2] = 0x80 | ((code >> 6) & 0x3f); utf8[3] = 0x80 | (code & 0x3f); return 4; } } static int __parse_json_string(const char *cursor, const char **end, char *str) { int ret; while (*cursor != '\"') { if (*cursor == '\\') { cursor++; switch (*cursor) { case '\"': *str = '\"'; break; case '\\': *str = '\\'; break; case '/': *str = '/'; break; case 'b': *str = '\b'; break; case 'f': *str = '\f'; break; case 'n': *str = '\n'; break; case 'r': *str = '\r'; break; case 't': *str = '\t'; break; case 'u': cursor++; ret = __parse_json_unicode(cursor, &cursor, str); if (ret < 0) return ret; str += ret; continue; default: return -2; } } else *str = *cursor; cursor++; str++; } *str = '\0'; *end = cursor + 1; return 0; } static const double __power_of_10[309] = { 1e0, 1e1, 1e2, 1e3, 1e4, 1e5, 1e6, 1e7, 1e8, 1e9, 1e10, 1e11, 1e12, 1e13, 1e14, 1e15, 1e16, 1e17, 1e18, 1e19, 1e20, 1e21, 1e22, 1e23, 1e24, 1e25, 1e26, 1e27, 1e28, 1e29, 1e30, 1e31, 1e32, 1e33, 1e34, 1e35, 1e36, 1e37, 1e38, 1e39, 1e40, 1e41, 1e42, 1e43, 1e44, 1e45, 1e46, 1e47, 1e48, 1e49, 1e50, 1e51, 1e52, 1e53, 1e54, 1e55, 1e56, 1e57, 1e58, 1e59, 1e60, 1e61, 1e62, 1e63, 1e64, 1e65, 1e66, 1e67, 1e68, 1e69, 1e70, 1e71, 1e72, 1e73, 1e74, 1e75, 1e76, 1e77, 1e78, 1e79, 1e80, 1e81, 1e82, 1e83, 1e84, 1e85, 1e86, 1e87, 1e88, 1e89, 1e90, 1e91, 1e92, 1e93, 1e94, 1e95, 1e96, 1e97, 1e98, 1e99, 1e100, 1e101, 1e102, 1e103, 1e104, 1e105, 1e106, 1e107, 1e108, 1e109, 1e110, 1e111, 1e112, 1e113, 1e114, 1e115, 1e116, 1e117, 1e118, 1e119, 1e120, 1e121, 1e122, 1e123, 1e124, 1e125, 1e126, 1e127, 1e128, 1e129, 1e130, 1e131, 1e132, 1e133, 1e134, 1e135, 1e136, 1e137, 1e138, 1e139, 1e140, 1e141, 1e142, 1e143, 1e144, 1e145, 1e146, 1e147, 1e148, 1e149, 
1e150, 1e151, 1e152, 1e153, 1e154, 1e155, 1e156, 1e157, 1e158, 1e159, 1e160, 1e161, 1e162, 1e163, 1e164, 1e165, 1e166, 1e167, 1e168, 1e169, 1e170, 1e171, 1e172, 1e173, 1e174, 1e175, 1e176, 1e177, 1e178, 1e179, 1e180, 1e181, 1e182, 1e183, 1e184, 1e185, 1e186, 1e187, 1e188, 1e189, 1e190, 1e191, 1e192, 1e193, 1e194, 1e195, 1e196, 1e197, 1e198, 1e199, 1e200, 1e201, 1e202, 1e203, 1e204, 1e205, 1e206, 1e207, 1e208, 1e209, 1e210, 1e211, 1e212, 1e213, 1e214, 1e215, 1e216, 1e217, 1e218, 1e219, 1e220, 1e221, 1e222, 1e223, 1e224, 1e225, 1e226, 1e227, 1e228, 1e229, 1e230, 1e231, 1e232, 1e233, 1e234, 1e235, 1e236, 1e237, 1e238, 1e239, 1e240, 1e241, 1e242, 1e243, 1e244, 1e245, 1e246, 1e247, 1e248, 1e249, 1e250, 1e251, 1e252, 1e253, 1e254, 1e255, 1e256, 1e257, 1e258, 1e259, 1e260, 1e261, 1e262, 1e263, 1e264, 1e265, 1e266, 1e267, 1e268, 1e269, 1e270, 1e271, 1e272, 1e273, 1e274, 1e275, 1e276, 1e277, 1e278, 1e279, 1e280, 1e281, 1e282, 1e283, 1e284, 1e285, 1e286, 1e287, 1e288, 1e289, 1e290, 1e291, 1e292, 1e293, 1e294, 1e295, 1e296, 1e297, 1e298, 1e299, 1e300, 1e301, 1e302, 1e303, 1e304, 1e305, 1e306, 1e307, 1e308 }; static double __evaluate_json_number(const char *integer, const char *fraction, int exp) { long long mant = 0; int figures = 0; double num; int sign; sign = (*integer == '-'); if (sign) integer++; if (*integer != '0') { mant = *integer - '0'; integer++; figures++; while (isdigit(*integer) && figures < 18) { mant *= 10; mant += *integer - '0'; integer++; figures++; } while (isdigit(*integer)) { exp++; integer++; } } else { while (*fraction == '0') { exp--; fraction++; } } while (isdigit(*fraction) && figures < 18) { mant *= 10; mant += *fraction - '0'; exp--; fraction++; figures++; } if (exp != 0 && figures != 0) { while (exp > 0 && figures < 18) { mant *= 10; exp--; figures++; } while (exp < 0 && mant % 10 == 0) { mant /= 10; exp++; figures--; } } num = mant; if (exp != 0 && figures != 0) { if (exp > 291) num = INFINITY; else if (exp > 0) num *= __power_of_10[exp]; else 
if (exp > -309) num /= __power_of_10[-exp]; else if (exp > -324 - figures) { num /= __power_of_10[-exp - 308]; num /= __power_of_10[308]; } else num = 0.0; } return sign ? -num : num; } static int __parse_json_number(const char *cursor, const char **end, double *num) { const char *integer = cursor; const char *fraction = ""; int exp = 0; int sign; if (*cursor == '-') cursor++; if (!isdigit(*cursor)) return -2; if (*cursor == '0' && isdigit(cursor[1])) return -2; cursor++; while (isdigit(*cursor)) cursor++; if (*cursor == '.') { cursor++; fraction = cursor; if (!isdigit(*cursor)) return -2; cursor++; while (isdigit(*cursor)) cursor++; } if (*cursor == 'E' || *cursor == 'e') { cursor++; sign = (*cursor == '-'); if (sign || *cursor == '+') cursor++; if (!isdigit(*cursor)) return -2; exp = *cursor - '0'; cursor++; while (isdigit(*cursor) && exp < 2000000) { exp *= 10; exp += *cursor - '0'; cursor++; } while (isdigit(*cursor)) cursor++; if (sign) exp = -exp; } if (cursor - integer > 1000000) return -2; *num = __evaluate_json_number(integer, fraction, exp); *end = cursor; return 0; } static void __insert_json_member(json_member_t *memb, struct list_head *pos, json_object_t *obj) { struct rb_node **p = &obj->root.rb_node; struct rb_node *parent = NULL; json_member_t *entry; while (*p) { parent = *p; entry = rb_entry(*p, json_member_t, rb); if (strcmp(memb->name, entry->name) < 0) p = &(*p)->rb_left; else p = &(*p)->rb_right; } rb_link_node(&memb->rb, parent, p); rb_insert_color(&memb->rb, &obj->root); list_add(&memb->list, pos); } static int __parse_json_value(const char *cursor, const char **end, int depth, json_value_t *val); static void __destroy_json_value(json_value_t *val); static int __parse_json_member(const char *cursor, const char **end, int depth, json_member_t *memb) { int ret; ret = __parse_json_string(cursor, &cursor, memb->name); if (ret < 0) return ret; while (isspace(*cursor)) cursor++; if (*cursor != ':') return -2; cursor++; while (isspace(*cursor)) 
cursor++; ret = __parse_json_value(cursor, &cursor, depth, &memb->value); if (ret < 0) return ret; *end = cursor; return 0; } static int __parse_json_members(const char *cursor, const char **end, int depth, json_object_t *obj) { json_member_t *memb; int cnt = 0; int ret; while (isspace(*cursor)) cursor++; if (*cursor == '}') { *end = cursor + 1; return 0; } while (1) { if (*cursor != '\"') return -2; cursor++; ret = __json_string_length(cursor); if (ret < 0) return ret; memb = (json_member_t *)malloc(offsetof(json_member_t, name) + ret + 1); if (!memb) return -1; ret = __parse_json_member(cursor, &cursor, depth, memb); if (ret < 0) { free(memb); return ret; } __insert_json_member(memb, obj->head.prev, obj); cnt++; while (isspace(*cursor)) cursor++; if (*cursor == ',') { cursor++; while (isspace(*cursor)) cursor++; } else if (*cursor == '}') break; else return -2; } *end = cursor + 1; return cnt; } static void __destroy_json_members(json_object_t *obj) { struct list_head *pos, *tmp; json_member_t *memb; list_for_each_safe(pos, tmp, &obj->head) { memb = list_entry(pos, json_member_t, list); __destroy_json_value(&memb->value); free(memb); } } static int __parse_json_object(const char *cursor, const char **end, int depth, json_object_t *obj) { int ret; if (depth == JSON_DEPTH_LIMIT) return -3; INIT_LIST_HEAD(&obj->head); obj->root.rb_node = NULL; ret = __parse_json_members(cursor, end, depth + 1, obj); if (ret < 0) { __destroy_json_members(obj); return ret; } obj->size = ret; return 0; } static int __parse_json_elements(const char *cursor, const char **end, int depth, json_array_t *arr) { json_element_t *elem; int cnt = 0; int ret; while (isspace(*cursor)) cursor++; if (*cursor == ']') { *end = cursor + 1; return 0; } while (1) { elem = (json_element_t *)malloc(sizeof (json_element_t)); if (!elem) return -1; ret = __parse_json_value(cursor, &cursor, depth, &elem->value); if (ret < 0) { free(elem); return ret; } list_add_tail(&elem->list, &arr->head); cnt++; while 
(isspace(*cursor)) cursor++; if (*cursor == ',') { cursor++; while (isspace(*cursor)) cursor++; } else if (*cursor == ']') break; else return -2; } *end = cursor + 1; return cnt; } static void __destroy_json_elements(json_array_t *arr) { struct list_head *pos, *tmp; json_element_t *elem; list_for_each_safe(pos, tmp, &arr->head) { elem = list_entry(pos, json_element_t, list); __destroy_json_value(&elem->value); free(elem); } } static int __parse_json_array(const char *cursor, const char **end, int depth, json_array_t *arr) { int ret; if (depth == JSON_DEPTH_LIMIT) return -3; INIT_LIST_HEAD(&arr->head); ret = __parse_json_elements(cursor, end, depth + 1, arr); if (ret < 0) { __destroy_json_elements(arr); return ret; } arr->size = ret; return 0; } static int __parse_json_value(const char *cursor, const char **end, int depth, json_value_t *val) { int ret; switch (*cursor) { case '\"': cursor++; ret = __json_string_length(cursor); if (ret < 0) return ret; val->value.string = (char *)malloc(ret + 1); if (!val->value.string) return -1; ret = __parse_json_string(cursor, end, val->value.string); if (ret < 0) { free(val->value.string); return ret; } val->type = JSON_VALUE_STRING; break; case '-': case '0': case '1': case '2': case '3': case '4': case '5': case '6': case '7': case '8': case '9': ret = __parse_json_number(cursor, end, &val->value.number); if (ret < 0) return ret; val->type = JSON_VALUE_NUMBER; break; case '{': cursor++; ret = __parse_json_object(cursor, end, depth, &val->value.object); if (ret < 0) return ret; val->type = JSON_VALUE_OBJECT; break; case '[': cursor++; ret = __parse_json_array(cursor, end, depth, &val->value.array); if (ret < 0) return ret; val->type = JSON_VALUE_ARRAY; break; case 't': if (strncmp(cursor, "true", 4) != 0) return -2; *end = cursor + 4; val->type = JSON_VALUE_TRUE; break; case 'f': if (strncmp(cursor, "false", 5) != 0) return -2; *end = cursor + 5; val->type = JSON_VALUE_FALSE; break; case 'n': if (strncmp(cursor, "null", 4) != 
0) return -2; *end = cursor + 4; val->type = JSON_VALUE_NULL; break; default: return -2; } return 0; } static void __destroy_json_value(json_value_t *val) { switch (val->type) { case JSON_VALUE_STRING: free(val->value.string); break; case JSON_VALUE_OBJECT: __destroy_json_members(&val->value.object); break; case JSON_VALUE_ARRAY: __destroy_json_elements(&val->value.array); break; } } json_value_t *json_value_parse(const char *cursor) { json_value_t *val; val = (json_value_t *)malloc(sizeof (json_value_t)); if (!val) return NULL; while (isspace(*cursor)) cursor++; if (__parse_json_value(cursor, &cursor, 0, val) >= 0) { while (isspace(*cursor)) cursor++; if (*cursor == '\0') return val; __destroy_json_value(val); } free(val); return NULL; } static void __move_json_value(json_value_t *src, json_value_t *dest) { switch (src->type) { case JSON_VALUE_STRING: dest->value.string = src->value.string; break; case JSON_VALUE_NUMBER: dest->value.number = src->value.number; break; case JSON_VALUE_OBJECT: INIT_LIST_HEAD(&dest->value.object.head); list_splice(&src->value.object.head, &dest->value.object.head); dest->value.object.root.rb_node = src->value.object.root.rb_node; dest->value.object.size = src->value.object.size; break; case JSON_VALUE_ARRAY: INIT_LIST_HEAD(&dest->value.array.head); list_splice(&src->value.array.head, &dest->value.array.head); dest->value.array.size = src->value.array.size; break; } dest->type = src->type; } static int __set_json_value(int type, va_list ap, json_value_t *val) { json_value_t *src; const char *str; int len; switch (type) { case 0: src = va_arg(ap, json_value_t *); __move_json_value(src, val); free(src); return 0; case JSON_VALUE_STRING: str = va_arg(ap, const char *); len = strlen(str); val->value.string = (char *)malloc(len + 1); if (!val->value.string) return -1; memcpy(val->value.string, str, len + 1); break; case JSON_VALUE_NUMBER: val->value.number = va_arg(ap, double); break; case JSON_VALUE_OBJECT: 
INIT_LIST_HEAD(&val->value.object.head); val->value.object.root.rb_node = NULL; val->value.object.size = 0; break; case JSON_VALUE_ARRAY: INIT_LIST_HEAD(&val->value.array.head); val->value.array.size = 0; break; } val->type = type; return 0; } json_value_t *json_value_create(int type, ...) { json_value_t *val; va_list ap; int ret; val = (json_value_t *)malloc(sizeof (json_value_t)); if (!val) return NULL; va_start(ap, type); ret = __set_json_value(type, ap, val); va_end(ap); if (ret < 0) { free(val); return NULL; } return val; } static int __copy_json_value(const json_value_t *src, json_value_t *dest); static int __copy_json_members(const json_object_t *src, json_object_t *dest) { struct list_head *pos; json_member_t *entry; json_member_t *memb; int len; list_for_each(pos, &src->head) { entry = list_entry(pos, json_member_t, list); len = strlen(entry->name); memb = (json_member_t *)malloc(offsetof(json_member_t, name) + len + 1); if (!memb) return -1; if (__copy_json_value(&entry->value, &memb->value) < 0) { free(memb); return -1; } memcpy(memb->name, entry->name, len + 1); __insert_json_member(memb, dest->head.prev, dest); } return src->size; } static int __copy_json_elements(const json_array_t *src, json_array_t *dest) { struct list_head *pos; json_element_t *entry; json_element_t *elem; list_for_each(pos, &src->head) { elem = (json_element_t *)malloc(sizeof (json_element_t)); if (!elem) return -1; entry = list_entry(pos, json_element_t, list); if (__copy_json_value(&entry->value, &elem->value) < 0) { free(elem); return -1; } list_add_tail(&elem->list, &dest->head); } return src->size; } static int __copy_json_value(const json_value_t *src, json_value_t *dest) { int len; switch (src->type) { case JSON_VALUE_STRING: len = strlen(src->value.string); dest->value.string = (char *)malloc(len + 1); if (!dest->value.string) return -1; memcpy(dest->value.string, src->value.string, len + 1); break; case JSON_VALUE_NUMBER: dest->value.number = src->value.number; break; 
case JSON_VALUE_OBJECT: INIT_LIST_HEAD(&dest->value.object.head); dest->value.object.root.rb_node = NULL; if (__copy_json_members(&src->value.object, &dest->value.object) < 0) { __destroy_json_members(&dest->value.object); return -1; } dest->value.object.size = src->value.object.size; break; case JSON_VALUE_ARRAY: INIT_LIST_HEAD(&dest->value.array.head); if (__copy_json_elements(&src->value.array, &dest->value.array) < 0) { __destroy_json_elements(&dest->value.array); return -1; } dest->value.array.size = src->value.array.size; break; } dest->type = src->type; return 0; } json_value_t *json_value_copy(const json_value_t *val) { json_value_t *copy; copy = (json_value_t *)malloc(sizeof (json_value_t)); if (!copy) return NULL; if (__copy_json_value(val, copy) < 0) { free(copy); return NULL; } return copy; } void json_value_destroy(json_value_t *val) { __destroy_json_value(val); free(val); } int json_value_type(const json_value_t *val) { return val->type; } const char *json_value_string(const json_value_t *val) { if (val->type != JSON_VALUE_STRING) return NULL; return val->value.string; } double json_value_number(const json_value_t *val) { if (val->type != JSON_VALUE_NUMBER) return NAN; return val->value.number; } json_object_t *json_value_object(const json_value_t *val) { if (val->type != JSON_VALUE_OBJECT) return NULL; return (json_object_t *)&val->value.object; } json_array_t *json_value_array(const json_value_t *val) { if (val->type != JSON_VALUE_ARRAY) return NULL; return (json_array_t *)&val->value.array; } const json_value_t *json_object_find(const char *name, const json_object_t *obj) { struct rb_node *p = obj->root.rb_node; json_member_t *memb; int n; while (p) { memb = rb_entry(p, json_member_t, rb); n = strcmp(name, memb->name); if (n < 0) p = p->rb_left; else if (n > 0) p = p->rb_right; else return &memb->value; } return NULL; } int json_object_size(const json_object_t *obj) { return obj->size; } const char *json_object_next_name(const char *name, const 
json_object_t *obj) { const struct list_head *pos; if (name) pos = &list_entry(name, json_member_t, name)->list; else pos = &obj->head; if (pos->next == &obj->head) return NULL; return list_entry(pos->next, json_member_t, list)->name; } const json_value_t *json_object_next_value(const json_value_t *val, const json_object_t *obj) { const struct list_head *pos; if (val) pos = &list_entry(val, json_member_t, value)->list; else pos = &obj->head; if (pos->next == &obj->head) return NULL; return &list_entry(pos->next, json_member_t, list)->value; } const char *json_object_prev_name(const char *name, const json_object_t *obj) { const struct list_head *pos; if (name) pos = &list_entry(name, json_member_t, name)->list; else pos = &obj->head; if (pos->prev == &obj->head) return NULL; return list_entry(pos->prev, json_member_t, list)->name; } const json_value_t *json_object_prev_value(const json_value_t *val, const json_object_t *obj) { const struct list_head *pos; if (val) pos = &list_entry(val, json_member_t, value)->list; else pos = &obj->head; if (pos->prev == &obj->head) return NULL; return &list_entry(pos->prev, json_member_t, list)->value; } const char *json_object_value_name(const json_value_t *val, const json_object_t *obj) { return list_entry(val, json_member_t, value)->name; } static const json_value_t *__json_object_insert(const char *name, int type, va_list ap, struct list_head *pos, json_object_t *obj) { json_member_t *memb; int len; len = strlen(name); memb = (json_member_t *)malloc(offsetof(json_member_t, name) + len + 1); if (!memb) return NULL; memcpy(memb->name, name, len + 1); if (__set_json_value(type, ap, &memb->value) < 0) { free(memb); return NULL; } __insert_json_member(memb, pos, obj); obj->size++; return &memb->value; } const json_value_t *json_object_append(json_object_t *obj, const char *name, int type, ...) 
{ const json_value_t *val; va_list ap; va_start(ap, type); val = __json_object_insert(name, type, ap, obj->head.prev, obj); va_end(ap); return val; } const json_value_t *json_object_insert_after(const json_value_t *val, json_object_t *obj, const char *name, int type, ...) { struct list_head *pos; va_list ap; if (val) pos = &list_entry(val, json_member_t, value)->list; else pos = &obj->head; va_start(ap, type); val = __json_object_insert(name, type, ap, pos, obj); va_end(ap); return val; } const json_value_t *json_object_insert_before(const json_value_t *val, json_object_t *obj, const char *name, int type, ...) { struct list_head *pos; va_list ap; if (val) pos = &list_entry(val, json_member_t, value)->list; else pos = &obj->head; va_start(ap, type); val = __json_object_insert(name, type, ap, pos->prev, obj); va_end(ap); return val; } json_value_t *json_object_remove(const json_value_t *val, json_object_t *obj) { json_member_t *memb = list_entry(val, json_member_t, value); val = (json_value_t *)malloc(sizeof (json_value_t)); if (!val) return NULL; list_del(&memb->list); rb_erase(&memb->rb, &obj->root); obj->size--; __move_json_value(&memb->value, (json_value_t *)val); free(memb); return (json_value_t *)val; } int json_array_size(const json_array_t *arr) { return arr->size; } const json_value_t *json_array_next_value(const json_value_t *val, const json_array_t *arr) { const struct list_head *pos; if (val) pos = &list_entry(val, json_element_t, value)->list; else pos = &arr->head; if (pos->next == &arr->head) return NULL; return &list_entry(pos->next, json_element_t, list)->value; } const json_value_t *json_array_prev_value(const json_value_t *val, const json_array_t *arr) { const struct list_head *pos; if (val) pos = &list_entry(val, json_element_t, value)->list; else pos = &arr->head; if (pos->prev == &arr->head) return NULL; return &list_entry(pos->prev, json_element_t, list)->value; } static const json_value_t *__json_array_insert(int type, va_list ap, struct 
list_head *pos, json_array_t *arr) { json_element_t *elem; elem = (json_element_t *)malloc(sizeof (json_element_t)); if (!elem) return NULL; if (__set_json_value(type, ap, &elem->value) < 0) { free(elem); return NULL; } list_add(&elem->list, pos); arr->size++; return &elem->value; } const json_value_t *json_array_append(json_array_t *arr, int type, ...) { const json_value_t *val; va_list ap; va_start(ap, type); val = __json_array_insert(type, ap, arr->head.prev, arr); va_end(ap); return val; } const json_value_t *json_array_insert_after(const json_value_t *val, json_array_t *arr, int type, ...) { struct list_head *pos; va_list ap; if (val) pos = &list_entry(val, json_element_t, value)->list; else pos = &arr->head; va_start(ap, type); val = __json_array_insert(type, ap, pos, arr); va_end(ap); return val; } const json_value_t *json_array_insert_before(const json_value_t *val, json_array_t *arr, int type, ...) { struct list_head *pos; va_list ap; if (val) pos = &list_entry(val, json_element_t, value)->list; else pos = &arr->head; va_start(ap, type); val = __json_array_insert(type, ap, pos->prev, arr); va_end(ap); return val; } json_value_t *json_array_remove(const json_value_t *val, json_array_t *arr) { json_element_t *elem = list_entry(val, json_element_t, value); val = (json_value_t *)malloc(sizeof (json_value_t)); if (!val) return NULL; list_del(&elem->list); arr->size--; __move_json_value(&elem->value, (json_value_t *)val); free(elem); return (json_value_t *)val; } workflow-0.11.8/src/util/json_parser.h000066400000000000000000000076151476003635400177360ustar00rootroot00000000000000/* Copyright (c) 2022 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Xie Han (xiehan@sogou-inc.com) */ #ifndef _JSON_PARSER_H_ #define _JSON_PARSER_H_ #include #define JSON_VALUE_STRING 1 #define JSON_VALUE_NUMBER 2 #define JSON_VALUE_OBJECT 3 #define JSON_VALUE_ARRAY 4 #define JSON_VALUE_TRUE 5 #define JSON_VALUE_FALSE 6 #define JSON_VALUE_NULL 7 typedef struct __json_value json_value_t; typedef struct __json_object json_object_t; typedef struct __json_array json_array_t; #ifdef __cplusplus extern "C" { #endif json_value_t *json_value_parse(const char *text); json_value_t *json_value_create(int type, ...); json_value_t *json_value_copy(const json_value_t *val); void json_value_destroy(json_value_t *val); int json_value_type(const json_value_t *val); const char *json_value_string(const json_value_t *val); double json_value_number(const json_value_t *val); json_object_t *json_value_object(const json_value_t *val); json_array_t *json_value_array(const json_value_t *val); const json_value_t *json_object_find(const char *name, const json_object_t *obj); int json_object_size(const json_object_t *obj); const char *json_object_next_name(const char *name, const json_object_t *obj); const json_value_t *json_object_next_value(const json_value_t *val, const json_object_t *obj); const char *json_object_prev_name(const char *name, const json_object_t *obj); const json_value_t *json_object_prev_value(const json_value_t *val, const json_object_t *obj); const char *json_object_value_name(const json_value_t *val, const json_object_t *obj); const json_value_t *json_object_append(json_object_t *obj, const char *name, int type, ...); const json_value_t 
*json_object_insert_after(const json_value_t *val, json_object_t *obj, const char *name, int type, ...); const json_value_t *json_object_insert_before(const json_value_t *val, json_object_t *obj, const char *name, int type, ...); json_value_t *json_object_remove(const json_value_t *val, json_object_t *obj); int json_array_size(const json_array_t *arr); const json_value_t *json_array_next_value(const json_value_t *val, const json_array_t *arr); const json_value_t *json_array_prev_value(const json_value_t *val, const json_array_t *arr); const json_value_t *json_array_append(json_array_t *arry, int type, ...); const json_value_t *json_array_insert_after(const json_value_t *val, json_array_t *arr, int type, ...); const json_value_t *json_array_insert_before(const json_value_t *val, json_array_t *arr, int type, ...); json_value_t *json_array_remove(const json_value_t *val, json_array_t *arr); #ifdef __cplusplus } #endif #define json_object_for_each(name, val, obj) \ for (name = NULL, val = NULL; \ name = json_object_next_name(name, obj), \ val = json_object_next_value(val, obj), val; ) #define json_object_for_each_prev(name, val, obj) \ for (name = NULL, val = NULL; \ name = json_object_prev_name(name, obj), \ val = json_object_prev_value(val, obj), val; ) #define json_array_for_each(val, arr) \ for (val = NULL; val = json_array_next_value(val, arr), val; ) #define json_array_for_each_prev(val, arr) \ for (val = NULL; val = json_array_prev_value(val, arr), val; ) #endif workflow-0.11.8/src/util/xmake.lua000066400000000000000000000004101476003635400170320ustar00rootroot00000000000000target("util") set_kind("object") add_files("*.c") add_files("*.cc") remove_files("crc32c.c") target("kafka_util") if has_config("kafka") then set_kind("object") add_files("crc32c.c") else set_kind("phony") end workflow-0.11.8/src/xmake.lua000066400000000000000000000040031476003635400160570ustar00rootroot00000000000000includes("**/xmake.lua") after_build(function (target) local lib_dir = 
get_config("workflow_lib") if (not os.isdir(lib_dir)) then os.mkdir(lib_dir) end shared_suffix = "*.so" if is_plat("macosx") then shared_suffix = "*.dylib" end if target:is_static() then os.mv(path.join("$(projectdir)", target:targetdir(), "*.a"), lib_dir) else os.mv(path.join("$(projectdir)", target:targetdir(), shared_suffix), lib_dir) end end) target("workflow") set_kind("$(kind)") add_deps("client", "factory", "kernel", "manager", "nameservice", "protocol", "server", "util") on_load(function (package) local include_path = path.join(get_config("workflow_inc"), "workflow") if (not os.isdir(include_path)) then os.mkdir(include_path) end os.cp(path.join("$(projectdir)", "src/include/**.h"), include_path) os.cp(path.join("$(projectdir)", "src/include/**.inl"), include_path) end) after_clean(function (target) os.rm(get_config("workflow_inc")) os.rm(get_config("workflow_lib")) os.rm("$(buildir)") end) on_install(function (target) os.mkdir(path.join(target:installdir(), "include/workflow")) os.mkdir(path.join(target:installdir(), "lib")) os.cp(path.join(get_config("workflow_inc"), "workflow"), path.join(target:installdir(), "include")) shared_suffix = "*.so" if is_plat("macosx") then shared_suffix = "*.dylib" end if target:is_static() then os.cp(path.join(get_config("workflow_lib"), "*.a"), path.join(target:installdir(), "lib")) else os.cp(path.join(get_config("workflow_lib"), shared_suffix), path.join(target:installdir(), "lib")) end end) target("wfkafka") if has_config("kafka") then set_kind("$(kind)") add_deps("kafka_client", "kafka_factory", "kafka_protocol", "kafka_util", "workflow") else set_kind("phony") end workflow-0.11.8/test/000077500000000000000000000000001476003635400144425ustar00rootroot00000000000000workflow-0.11.8/test/CMakeLists.txt000066400000000000000000000035221476003635400172040ustar00rootroot00000000000000cmake_minimum_required(VERSION 3.6) set(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING "build type") project(workflow_test LANGUAGES C CXX ) 
find_library(LIBRT rt) find_package(OpenSSL REQUIRED) find_package(workflow REQUIRED CONFIG HINTS ..) include_directories(${OPENSSL_INCLUDE_DIR} ${WORKFLOW_INCLUDE_DIR}) link_directories(${WORKFLOW_LIB_DIR}) find_program(CMAKE_MEMORYCHECK_COMMAND valgrind) set(memcheck_command ${CMAKE_MEMORYCHECK_COMMAND} ${CMAKE_MEMORYCHECK_COMMAND_OPTIONS} --error-exitcode=1 --leak-check=full) add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND}) enable_testing() set(CXX_STD "c++14") find_package(GTest REQUIRED) if (WIN32) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /MP /wd4200") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP /wd4200 /std:c++14") else () set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wall -fPIC -pipe -std=gnu90") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -fPIC -pipe -std=${CXX_STD} -fno-exceptions") if (APPLE) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-deprecated-declarations") endif() endif () set(TEST_LIST task_unittest algo_unittest http_unittest redis_unittest mysql_unittest facilities_unittest graph_unittest memory_unittest upstream_unittest dns_unittest resource_unittest uriparser_unittest ) if (APPLE) set(WORKFLOW_LIB workflow pthread OpenSSL::SSL OpenSSL::Crypto) else () set(WORKFLOW_LIB workflow pthread OpenSSL::SSL OpenSSL::Crypto ${LIBRT}) endif () foreach(src ${TEST_LIST}) add_executable(${src} EXCLUDE_FROM_ALL ${src}.cc) target_link_libraries(${src} ${WORKFLOW_LIB} GTest::GTest GTest::Main) add_test(${src} ${src}) add_dependencies(check ${src}) endforeach() if (NOT ${CMAKE_MEMORYCHECK_COMMAND} STREQUAL "CMAKE_MEMORYCHECK_COMMAND-NOTFOUND") foreach(src ${TEST_LIST}) add_test(${src}-memory-check ${memcheck_command} ./${src}) endforeach() endif () workflow-0.11.8/test/GNUmakefile000066400000000000000000000014361476003635400165200ustar00rootroot00000000000000ROOT_DIR := $(shell dirname $(realpath $(firstword $(MAKEFILE_LIST)))) ALL_TARGETS := all check clean MAKE_FILE := Makefile DEFAULT_BUILD_DIR := build.cmake BUILD_DIR := $(shell if [ -f $(MAKE_FILE) ]; then echo 
"."; else echo $(DEFAULT_BUILD_DIR); fi) CMAKE3 := $(shell if which cmake3 ; then echo cmake3; else echo cmake; fi;) .PHONY: $(ALL_TARGETS) all: mkdir -p $(BUILD_DIR) ifeq ($(DEBUG),y) cd $(BUILD_DIR) && $(CMAKE3) -D CMAKE_BUILD_TYPE=Debug $(ROOT_DIR) else cd $(BUILD_DIR) && $(CMAKE3) $(ROOT_DIR) endif make -C $(BUILD_DIR) -f Makefile check: mkdir -p $(BUILD_DIR) cd $(BUILD_DIR) && $(CMAKE3) $(ROOT_DIR) make -C $(BUILD_DIR) check CTEST_OUTPUT_ON_FAILURE=1 clean: ifeq ($(MAKE_FILE), $(wildcard $(MAKE_FILE))) -make -f Makefile clean endif rm -rf $(DEFAULT_BUILD_DIR) workflow-0.11.8/test/algo_unittest.cc000066400000000000000000000042761476003635400176430ustar00rootroot00000000000000/* Copyright (c) 2020 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Author: Wu Jiaxu (wujiaxu@sogou-inc.com) */ #include #include #include #include #include "workflow/WFAlgoTaskFactory.h" static void __arr_init(int *arr, int n) { srand(time(NULL)); for (int i = 0; i < n; i++) arr[i] = rand() % 65536; } static void __arr_check(int *arr, int n) { for (int i = 1; i < n; i++) EXPECT_LE(arr[i - 1], arr[i]); } TEST(algo_unittest, sort) { static constexpr int n = 100000; int *arr = new int[n]; __arr_init(arr, n); std::mutex mutex; std::condition_variable cond; bool done = false; auto *task = WFAlgoTaskFactory::create_sort_task("sort", arr, arr + n, [&mutex, &cond, &done](WFSortTask *task) { int *first = task->get_input()->first; int *last = task->get_input()->last; __arr_check(first, last - first); mutex.lock(); done = true; mutex.unlock(); cond.notify_one(); }); task->start(); std::unique_lock lock(mutex); while (!done) cond.wait(lock); lock.unlock(); delete []arr; } TEST(algo_unittest, parallel_sort) { static constexpr int n = 100000; int *arr = new int[n]; __arr_init(arr, n); std::mutex mutex; std::condition_variable cond; bool done = false; auto *task = WFAlgoTaskFactory::create_psort_task("psort", arr, arr + n, [&mutex, &cond, &done](WFSortTask *task) { int *first = task->get_input()->first; int *last = task->get_input()->last; __arr_check(first, last - first); mutex.lock(); done = true; mutex.unlock(); cond.notify_one(); }); task->start(); std::unique_lock lock(mutex); while (!done) cond.wait(lock); lock.unlock(); delete []arr; } workflow-0.11.8/test/dns_unittest.cc000066400000000000000000000053321476003635400174770ustar00rootroot00000000000000/* Copyright (c) 2021 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Liu Kai (liukaidx@sogou-inc.com) */ #include #include #include "workflow/WFTaskFactory.h" #include "workflow/WFDnsClient.h" #define RETRY_MAX 3 TEST(dns_unittest, WFDnsTaskCreate1) { std::string url = "dns://119.29.29.29/www.sogou.com"; auto *task = WFTaskFactory::create_dns_task(url, 0, NULL); task->dismiss(); } TEST(dns_unittest, WFDnsTaskCreate2) { std::string url = "http://119.29.29.29:dns/"; std::promise done; auto *task = WFTaskFactory::create_dns_task(url, 0, [&done] (WFDnsTask *task) { done.set_value(); }); task->start(); done.get_future().get(); } TEST(dns_unittest, WFDnsTask) { std::string url = "dns://119.29.29.29/www.sogou.com"; unsigned short req_id = 0x1234; std::promise done; auto *task = WFTaskFactory::create_dns_task(url, RETRY_MAX, [&done, req_id] (WFDnsTask *task) { int state = task->get_state(); if (state == WFT_STATE_SUCCESS) { unsigned short resp_id = task->get_resp()->get_id(); EXPECT_TRUE(req_id == resp_id); } done.set_value(); }); auto *req = task->get_req(); req->set_id(req_id); req->set_rd(1); req->set_question_type(DNS_TYPE_A); task->start(); auto fut = done.get_future(); fut.get(); } TEST(dns_unittest, WFDnsClientInit1) { WFDnsClient client; if (client.init("bad") >= 0) client.deinit(); } TEST(dns_unittest, WFDnsClientInit2) { WFDnsClient client; int ret = client.init("0.0.0.0,0.0.0.1:1,dns://0.0.0.2,dnss://0.0.0.3"); EXPECT_TRUE(ret >= 0); client.deinit(); } TEST(dns_unittest, WFDnsClient) { unsigned short req_id = 0x4321; std::promise done; WFDnsClient client; client.init("dns://119.29.29.29/"); auto *task = 
client.create_dns_task("www.sogou.com", [&done, req_id] (WFDnsTask *task) { int state = task->get_state(); if (state == WFT_STATE_SUCCESS) { unsigned short resp_id = task->get_resp()->get_id(); EXPECT_TRUE(req_id == resp_id); } done.set_value(); }); client.deinit(); auto *req = task->get_req(); req->set_id(req_id); task->start(); auto fut = done.get_future(); fut.get(); } int main(int argc, char *argv[]) { ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } workflow-0.11.8/test/facilities_unittest.cc000066400000000000000000000071341476003635400210310ustar00rootroot00000000000000/* Copyright (c) 2020 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Author: Li Yingxin (liyingxin@sogou-inc.com) */ #include #include #include #include #include #include "workflow/WFFacilities.h" #include "workflow/HttpUtil.h" #define GET_CURRENT_MICRO std::chrono::duration_cast(std::chrono::steady_clock::now().time_since_epoch()).count() TEST(facilities_unittest, usleep) { int64_t st = GET_CURRENT_MICRO; WFFacilities::usleep(1000000); int64_t ed = GET_CURRENT_MICRO; EXPECT_LE(ed - st, 10000000) << "usleep too slow"; } TEST(facilities_unittest, async_usleep) { int64_t st = GET_CURRENT_MICRO; WFFacilities::async_usleep(1000000).wait(); int64_t ed = GET_CURRENT_MICRO; EXPECT_LE(ed - st, 10000000) << "async_usleep too slow"; } TEST(facilities_unittest, request) { protocol::HttpRequest req; req.set_method(HttpMethodGet); req.set_http_version("HTTP/1.1"); req.set_request_uri("/"); req.set_header_pair("Host", "github.com"); auto res = WFFacilities::request(TT_TCP, "http://github.com", std::move(req), 0); //EXPECT_EQ(res.task_state, WFT_STATE_SUCCESS); if (res.task_state == WFT_STATE_SUCCESS) { auto code = atoi(res.resp.get_status_code()); EXPECT_TRUE(code == HttpStatusOK || code == HttpStatusMovedPermanently || code == HttpStatusFound || code == HttpStatusSeeOther || code == HttpStatusTemporaryRedirect || code == HttpStatusPermanentRedirect); } } TEST(facilities_unittest, async_request) { protocol::HttpRequest req; req.set_method(HttpMethodGet); req.set_http_version("HTTP/1.1"); req.set_request_uri("/"); req.set_header_pair("Host", "github.com"); auto res = WFFacilities::request(TT_TCP_SSL, "https://github.com", std::move(req), 0); //EXPECT_EQ(res.task_state, WFT_STATE_SUCCESS); if (res.task_state == WFT_STATE_SUCCESS) { auto code = atoi(res.resp.get_status_code()); EXPECT_TRUE(code == HttpStatusOK || code == HttpStatusMovedPermanently || code == HttpStatusFound || code == HttpStatusSeeOther || code == HttpStatusTemporaryRedirect || code == HttpStatusPermanentRedirect); } } TEST(facilities_unittest, fileIO) { uint64_t data = 0x1234; 
ssize_t sz; int fd = open("test.test", O_RDWR | O_TRUNC | O_CREAT, 0644); sz = WFFacilities::async_pwrite(fd, &data, 8, 0).get(); EXPECT_EQ(sz, 8); data = 0; sz = WFFacilities::async_pread(fd, &data, 8, 0).get(); EXPECT_EQ(sz, 8); EXPECT_EQ(data, 0x1234); close(fd); } static inline void f(int i, WFFacilities::WaitGroup *wg) { wg->done(); } TEST(facilities_unittest, WaitGroup) { WFFacilities::WaitGroup wg(100); for (int i = 0; i < 100; i++) WFFacilities::go("facilities", f, i, &wg); wg.wait(); WFFacilities::WaitGroup wg2(-100); wg2.wait(); WFFacilities::WaitGroup wg3(0); wg3.wait(); } #if OPENSSL_VERSION_NUMBER >= 0x10100000L #include int main(int argc, char* argv[]) { OPENSSL_init_ssl(0, 0); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } #endif workflow-0.11.8/test/graph_unittest.cc000066400000000000000000000040741476003635400200160ustar00rootroot00000000000000/* Copyright (c) 2020 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Author: Liu Yang (liuyang216492@sogou-inc.com) */ #include #include #include "workflow/WFTaskFactory.h" #include "workflow/WFFacilities.h" static SubTask *create_task(int& target) { static std::atomic generator; return WFTaskFactory::create_timer_task(0, [&](WFTimerTask *) { target = generator++; }); } TEST(graph_unittest, WFGraphTask1) { WFFacilities::WaitGroup wait_group(1); auto graph = WFTaskFactory::create_graph_task([&wait_group](WFGraphTask *){ wait_group.done(); }); int ta, tb, tc, td; auto& a = graph->create_graph_node(create_task(ta)); auto& b = graph->create_graph_node(create_task(tb)); auto& c = graph->create_graph_node(create_task(tc)); auto& d = graph->create_graph_node(create_task(td)); a --> b <-- c --> d --> a; c --> a; graph->start(); wait_group.wait(); EXPECT_LT(ta, tb); EXPECT_LT(tc, tb); EXPECT_LT(tc, td); EXPECT_LT(td, ta); EXPECT_LT(tc, ta); } TEST(graph_unittest, WFGraphTask2) { WFFacilities::WaitGroup wait_group(1); auto graph = WFTaskFactory::create_graph_task([&wait_group](WFGraphTask *){ wait_group.done(); }); constexpr int N = 4096 - 1; auto target = new int[N]; auto node = new WFGraphNode *[N]; for (int i = 0; i < N; i++) node[i] = &graph->create_graph_node(create_task(target[i])); for (int i = 1; i < N; i++) node[i]->precede(*node[(i - 1) / 2]); graph->start(); wait_group.wait(); for (int i = 1; i < N; i++) EXPECT_LT(target[i], target[(i - 1) / 2]); delete[] target; delete[] node; } workflow-0.11.8/test/http_unittest.cc000066400000000000000000000164701476003635400176770ustar00rootroot00000000000000/* Copyright (c) 2020 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Wu Jiaxu (wujiaxu@sogou-inc.com) */ #include #include #include #include #include "workflow/WFTaskFactory.h" #include "workflow/WFOperator.h" #include "workflow/WFHttpServer.h" #include "workflow/HttpUtil.h" #define RETRY_MAX 3 static void __http_process(WFHttpTask *task) { auto *req = task->get_req(); auto *resp = task->get_resp(); EXPECT_TRUE(strcmp(req->get_request_uri(), "/test") == 0); resp->add_header_pair("Content-Type", "text/plain"); } TEST(http_unittest, WFHttpTask1) { std::mutex mutex; std::condition_variable cond; bool done = false; auto *task = WFTaskFactory::create_http_task("http://github.com", 0, RETRY_MAX, [&mutex, &cond, &done](WFHttpTask *task) { auto state = task->get_state(); //EXPECT_EQ(state, WFT_STATE_SUCCESS); if (state == WFT_STATE_SUCCESS) { auto code = atoi(task->get_resp()->get_status_code()); EXPECT_TRUE(code == HttpStatusOK || code == HttpStatusMovedPermanently || code == HttpStatusFound || code == HttpStatusSeeOther || code == HttpStatusTemporaryRedirect || code == HttpStatusPermanentRedirect); } mutex.lock(); done = true; mutex.unlock(); cond.notify_one(); }); task->start(); std::unique_lock lock(mutex); while (!done) cond.wait(lock); lock.unlock(); } TEST(http_unittest, WFHttpTask2) { std::mutex mutex; std::condition_variable cond; bool done = false; auto *task = WFTaskFactory::create_http_task("http://github.com", 1, RETRY_MAX, [&mutex, &cond, &done](WFHttpTask *task) { auto state = task->get_state(); //EXPECT_EQ(state, WFT_STATE_SUCCESS); if (state == WFT_STATE_SUCCESS) { auto code = atoi(task->get_resp()->get_status_code()); 
EXPECT_TRUE(code == HttpStatusOK || code == HttpStatusMovedPermanently || code == HttpStatusFound || code == HttpStatusSeeOther || code == HttpStatusTemporaryRedirect || code == HttpStatusPermanentRedirect); } mutex.lock(); done = true; mutex.unlock(); cond.notify_one(); }); task->start(); std::unique_lock lock(mutex); while (!done) cond.wait(lock); lock.unlock(); } TEST(http_unittest, WFHttpTask3) { FILE *f; f = fopen("server.crt", "w"); fputs(R"( -----BEGIN CERTIFICATE----- MIIDrjCCApYCCQCzDnhp/eqaRTANBgkqhkiG9w0BAQUFADCBmDELMAkGA1UEBhMC Q04xEDAOBgNVBAgMB0JlaWppbmcxEDAOBgNVBAcMB0JlaWppbmcxFzAVBgNVBAoM DlNvZ291LmNvbSBJbmMuMRYwFAYDVQQLDA13d3cuc29nb3UuY29tMQ8wDQYDVQQD DAZ4aWVoYW4xIzAhBgkqhkiG9w0BCQEWFHhpZWhhbkBzb2dvdS1pbmMuY29tMB4X DTE5MDYxMTA5MjQxNloXDTIwMDYxMDA5MjQxNlowgZgxCzAJBgNVBAYTAkNOMRAw DgYDVQQIDAdCZWlqaW5nMRAwDgYDVQQHDAdCZWlqaW5nMRcwFQYDVQQKDA5Tb2dv dS5jb20gSW5jLjEWMBQGA1UECwwNd3d3LnNvZ291LmNvbTEPMA0GA1UEAwwGeGll aGFuMSMwIQYJKoZIhvcNAQkBFhR4aWVoYW5Ac29nb3UtaW5jLmNvbTCCASIwDQYJ KoZIhvcNAQEBBQADggEPADCCAQoCggEBALB6E1+lnuey24j+BwcD21h5t/xD+K6I thHiyT3S8fztAd+BfyphT+KLhbHbJFUaz7tfoV8lyBDdyVlgfwlCLyCp2sNcaCwg TF+XjTWOkDtg5+rCgoHRUjLNIJ2auO/5780DZcaL41gwzAu5rwE3sOifIZ4XI5WO 6zrd5MUFhpHy91Sz1sxcCLXwQEgPDsa10/6k5bSd8xYP29yZ80lZeJ++5fgOf/AU JkANXLjsHnfOFV42Je/6EEcqe0YM6kjA9d4d5TS+To5YPfObTTR21Cey4RD5Ijjg 4/VGdtI6tDWa3+N/CVVc8CKLVGNCVyAGWoBXCZuzlfex9Z0jtY2dd1cCAwEAATAN BgkqhkiG9w0BAQUFAAOCAQEAoLALHvGt0xCsDsYxxQ3biioPa2djT5jN8/QI17QF 7C+0IdFEJi6dwF/O0rPgHbVSMZB7pPl5gx/rC4bWg9CYvZmlptmDJym+SpR0CBLC /LXEFsA7VmkdAiG6CHLtg1uZy0LTN0sRMdLNIetm6PBcnr3JEB8erayRaYy1Qk7d 6O+3KexviFX/dAJRj59AIYXoMwji2ZYowXH+InNVF8UEunynJGURJJGQXFh0R18Q SniEJZux/WkxaOkqMBHtXtdkowpSMjn/RUA5dVu5Zjyf8LL9cjBmyKMxLXKeQeKK 0ylFmFZxY8GawFdCq4XUKzSuLw4/orfuKn/ViSSixuXL5A== -----END CERTIFICATE----- )", f); fclose(f); f = fopen("server.key", "w"); fputs(R"( -----BEGIN RSA PRIVATE KEY----- MIIEogIBAAKCAQEAsHoTX6We57LbiP4HBwPbWHm3/EP4roi2EeLJPdLx/O0B34F/ 
KmFP4ouFsdskVRrPu1+hXyXIEN3JWWB/CUIvIKnaw1xoLCBMX5eNNY6QO2Dn6sKC gdFSMs0gnZq47/nvzQNlxovjWDDMC7mvATew6J8hnhcjlY7rOt3kxQWGkfL3VLPW zFwItfBASA8OxrXT/qTltJ3zFg/b3JnzSVl4n77l+A5/8BQmQA1cuOwed84VXjYl 7/oQRyp7RgzqSMD13h3lNL5Ojlg985tNNHbUJ7LhEPkiOODj9UZ20jq0NZrf438J VVzwIotUY0JXIAZagFcJm7OV97H1nSO1jZ13VwIDAQABAoIBAFPW+yNCjLaouzFe 9bm4dFmZIfZf2GIaotzmcBLGB57QfkZPwDlDF++Ztz9iy+T+otfyu7h3O4//veuP M2sTnU4YQ8zyNq9X/NChMD3UZ+M9y5A1Lkk8R5/I4gjd+6ROikVMqupjhPNd42Ji qaiba5loGFGBzq77wfcqece8M01cZTnCtZ5ZdFrxzWWd9EaKhXf6Mkibaf6Y4/Oi GVvhqKK7Yv4f+xX85GnZuBv8hau6nCfiC/5zYKm8SiAoWE1TikMZGd2+bwAE1COh qeVJyevA7XcP8z+dtqb0hBHqlm0DTyVmu/cuHAZHxYms7VvJ2isWKI4gl1MY3zD3 ODHEeHECgYEA36eVhGCAQeAP3eTtEq1dcSSsb3bEKTpZGxj6BT89HRp0qcw/dKQV oITXMeSJpIRR879mi5FBFHlvTb0xkI96O5fXuAz/A7hSOtZpiJ4G3tAEplbPJhmB 3km3syRXqXuv8m38Zjb9FOgu7D/OSWYe8QGWM/rrDjgBfJNveKlWn/kCgYEAyf/R heAvuFxqf77XRzjBhil1N09f9mw8yagFritNyy8Wb+SlNSHIBZ9WSKVdVxyA4GOe A/0yAY7r9i/Y1sMnCt0kL5UEwY2xlbA+Ld/B/5MjEN4mP9g5a2goj75w7CBT/YLh dAfNwN08wsTNl/53tovhqz1uvU+muAWQnAgURc8CgYAjqKOFHKG2XxQIi+RkkvGQ BYncp7H05NGqKVxLk96ZkktBe0guv66XDjcFRGvRqCss0rp1zC31JrthSKXrZ4TU lYwWUzQhkrTBnsfquU9dHQtwvex/JZf4Kga48DVt10OhQnn4jhHh0HcSwcWRHFAY muko1nu9o55RD2y5bz5ZeQKBgFfzec/3n+9+1aQPfP52uNRogq/1cIwD7qfC7844 7qNUOkm33TL4JXZFPTVeQvjl4TtSRH/qI3bIOvczOA+yYvJ4/QN2t95qinLpjPk+ XuKftvnmL/NGeyHH9Tk5K0O0g71y2iVCLJUX/xeyxu2yD3+9AiIkGm51GtsvGRrG 7cTDAoGAIlzSgiMSMkRUpzyJYvRd5o+Bt+v+SHDni40XrfZqc4cmh8MVPdVkNMFi a/7MiJf+tw5lRG/Oks0pNOvFIpTXi8ncxW9tgQfy2hN6LMGD7uIu/X9uMJmwvNtj KZ1lOvb+vi3TLrQf4tfBekrXXe5tZK40QSJ7UdtY7HHrrbAXU+8= -----END RSA PRIVATE KEY----- )", f); fclose(f); WFHttpServer http_server(__http_process); EXPECT_TRUE(http_server.start("127.0.0.1", 8811) == 0) << "http server start failed"; WFHttpServer https_server(__http_process); EXPECT_TRUE(https_server.start("127.0.0.1", 8822, "server.crt", "server.key") == 0) << "https server start failed"; std::mutex mutex; std::condition_variable cond; bool done = false; auto cb = [](WFHttpTask *task) { auto state = 
task->get_state(); EXPECT_EQ(state, WFT_STATE_SUCCESS); if (state == WFT_STATE_SUCCESS) { auto *resp = task->get_resp(); auto code = atoi(resp->get_status_code()); EXPECT_EQ(code, HttpStatusOK); protocol::HttpHeaderCursor cursor(resp); std::string content_type; EXPECT_TRUE(cursor.find("Content-Type", content_type)); EXPECT_TRUE(content_type == "text/plain"); } }; auto *A = WFTaskFactory::create_http_task("http://127.0.0.1:8811/test", 0, RETRY_MAX, cb); auto *B = WFTaskFactory::create_http_task("https://127.0.0.1:8822/test", 0, RETRY_MAX, cb); auto& flow = *A > B; flow.set_callback([&mutex, &cond, &done](const SeriesWork *series) { mutex.lock(); done = true; mutex.unlock(); cond.notify_one(); }); flow.start(); std::unique_lock lock(mutex); while (!done) cond.wait(lock); lock.unlock(); http_server.stop(); https_server.stop(); } #if OPENSSL_VERSION_NUMBER >= 0x10100000L #include int main(int argc, char* argv[]) { OPENSSL_init_ssl(0, 0); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } #endif workflow-0.11.8/test/memory_unittest.cc000066400000000000000000000041511476003635400202210ustar00rootroot00000000000000/* Copyright (c) 2020 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Author: Liu Yang (liuyang216492@sogou-inc.com) */ #include #include #include "workflow/WFTaskFactory.h" TEST(memory_unittest, dismiss) { std::vector tasks; auto *http_task = WFTaskFactory::create_http_task("http://www.sogou.com", 0, 0, nullptr); tasks.push_back(http_task); auto *redis_task = WFTaskFactory::create_redis_task("redis://username:password@127.0.0.1:6676/1", 0, nullptr); tasks.push_back(redis_task); auto *mysql_task = WFTaskFactory::create_mysql_task("mysql://username:password@127.0.0.1:8899/db", 0, nullptr); tasks.push_back(mysql_task); auto *timer_task = WFTaskFactory::create_timer_task(0, nullptr); tasks.push_back(timer_task); auto *counter_task = WFTaskFactory::create_counter_task("", 1, nullptr); tasks.push_back(counter_task); auto *go_task = WFTaskFactory::create_go_task("", [](){}); tasks.push_back(go_task); auto *thread_task = WFThreadTaskFactory::create_thread_task("", [](int *, int *){}, nullptr); tasks.push_back(thread_task); auto *graph_task = WFTaskFactory::create_graph_task(nullptr); auto& node_a = graph_task->create_graph_node(WFTaskFactory::create_timer_task(0, nullptr)); auto& node_b = graph_task->create_graph_node(WFTaskFactory::create_timer_task(0, nullptr)); node_a -->-- node_b; tasks.push_back(graph_task); auto *parallel_work = Workflow::create_parallel_work(nullptr); for (auto *task : tasks) { auto *series_work = Workflow::create_series_work(task, nullptr); parallel_work->add_series(series_work); } parallel_work->dismiss(); } workflow-0.11.8/test/mysql_unittest.cc000066400000000000000000000040021476003635400200510ustar00rootroot00000000000000/* Copyright (c) 2020 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Wu Jiaxu (wujiaxu@sogou-inc.com) */ #include #include #include #include #include "workflow/WFTaskFactory.h" #include "workflow/WFMySQLServer.h" #define RETRY_MAX 3 static void __mysql_process(WFMySQLTask *task) { //auto *req = task->get_req(); auto *resp = task->get_resp(); resp->set_ok_packet(); } static void test_client(const char *url, const char *sql, std::mutex& mutex, std::condition_variable& cond, bool& done) { auto *task = WFTaskFactory::create_mysql_task(url, RETRY_MAX, [&mutex, &cond, &done](WFMySQLTask *task) { auto state = task->get_state(); EXPECT_EQ(state, WFT_STATE_SUCCESS); mutex.lock(); done = true; mutex.unlock(); cond.notify_one(); }); task->get_req()->set_query(sql); task->start(); } TEST(mysql_unittest, WFMySQLTask1) { std::mutex mutex; std::condition_variable cond; bool done = false; WFMySQLServer server(__mysql_process); EXPECT_TRUE(server.start("127.0.0.1", 8899) == 0) << "server start failed"; test_client("mysql://testuser:testpass@127.0.0.1:8899/testdb", "select * from testtable limit 3", mutex, cond, done); std::unique_lock lock(mutex); while (!done) cond.wait(lock); lock.unlock(); server.stop(); } #if OPENSSL_VERSION_NUMBER >= 0x10100000L #include int main(int argc, char* argv[]) { OPENSSL_init_ssl(0, 0); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } #endif workflow-0.11.8/test/redis_unittest.cc000066400000000000000000000077101476003635400200230ustar00rootroot00000000000000/* Copyright (c) 2020 Sogou, Inc. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Wu Jiaxu (wujiaxu@sogou-inc.com) */ #include #include #include #include #include #include #include #include "workflow/WFTaskFactory.h" #include "workflow/WFRedisServer.h" #include "workflow/WFOperator.h" #define RETRY_MAX 3 static void __redis_process(WFRedisTask *task) { auto *req = task->get_req(); auto *resp = task->get_resp(); EXPECT_TRUE(req->parse_success()); std::string cmd; std::vector params; protocol::RedisValue val; EXPECT_TRUE(req->get_command(cmd)); EXPECT_TRUE(req->get_params(params)); if (strcasecmp(cmd.c_str(), "SET") == 0) { EXPECT_EQ(params.size(), 2); EXPECT_TRUE(params[0] == "testkey"); EXPECT_TRUE(params[1] == "testvalue"); val.set_status("OK"); } else if (strcasecmp(cmd.c_str(), "GET") == 0) { EXPECT_EQ(params.size(), 1); val.set_string("testvalue"); } else if (strcasecmp(cmd.c_str(), "DEL") == 0) { EXPECT_EQ(params.size(), 1); EXPECT_TRUE(params[0] == "testkey"); val.set_status("OK"); } else if (strcasecmp(cmd.c_str(), "SELECT") == 0) { EXPECT_EQ(params.size(), 1); EXPECT_TRUE(params[0] == "6"); val.set_status("OK"); } else if (strcasecmp(cmd.c_str(), "AUTH") == 0) { EXPECT_EQ(params.size(), 1); EXPECT_TRUE(params[0] == "testpass"); val.set_status("OK"); } else { EXPECT_TRUE(0) << "Command Not Support"; val.set_error("Command Not Support"); } resp->set_result(val); } static void test_client(const char *url, std::mutex& mutex, std::condition_variable& cond, bool& done) { auto&& set_cb = [](WFRedisTask *task) { auto state = 
task->get_state(); auto *resp = task->get_resp(); protocol::RedisValue val; EXPECT_EQ(state, WFT_STATE_SUCCESS); EXPECT_TRUE(resp->parse_success()); resp->get_result(val); EXPECT_TRUE(val.is_ok()); }; auto&& get_cb = [](WFRedisTask *task) { auto state = task->get_state(); auto *resp = task->get_resp(); protocol::RedisValue val; EXPECT_EQ(state, WFT_STATE_SUCCESS); EXPECT_TRUE(resp->parse_success()); resp->get_result(val); EXPECT_TRUE(val.is_string()); EXPECT_TRUE(val.string_value() == "testvalue"); }; auto&& del_cb = [](WFRedisTask *task) { auto state = task->get_state(); auto *resp = task->get_resp(); protocol::RedisValue val; EXPECT_EQ(state, WFT_STATE_SUCCESS); EXPECT_TRUE(resp->parse_success()); resp->get_result(val); EXPECT_TRUE(val.is_ok()); }; auto *A = WFTaskFactory::create_redis_task(url, RETRY_MAX, std::move(set_cb)); auto *B = WFTaskFactory::create_redis_task(url, RETRY_MAX, std::move(get_cb)); auto *C = WFTaskFactory::create_redis_task(url, RETRY_MAX, std::move(del_cb)); auto& flow = *A > B > C; A->get_req()->set_request("SET", {"testkey", "testvalue"}); B->get_req()->set_request("GET", {"testkey"}); C->get_req()->set_request("DEL", {"testkey"}); flow.set_callback([&mutex, &cond, &done](const SeriesWork *series) { mutex.lock(); done = true; mutex.unlock(); cond.notify_one(); }); flow.start(); } TEST(redis_unittest, WFRedisTask1) { std::mutex mutex; std::condition_variable cond; bool done = false; WFRedisServer server(__redis_process); EXPECT_TRUE(server.start("127.0.0.1", 6677) == 0) << "server start failed"; test_client("redis://:testpass@127.0.0.1:6677/6", mutex, cond, done); std::unique_lock lock(mutex); while (!done) cond.wait(lock); lock.unlock(); server.stop(); } workflow-0.11.8/test/resource_unittest.cc000066400000000000000000000032111476003635400205340ustar00rootroot00000000000000/* Copyright (c) 2021 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Li Yingxin (liyingxin@sogou-inc.com) */ #include #include #include #include #include "workflow/WFTask.h" #include "workflow/WFTaskFactory.h" #include "workflow/WFResourcePool.h" #include "workflow/WFFacilities.h" TEST(resource_unittest, resource_pool) { int res_concurrency = 3; int task_concurrency = 10; const char *words[3] = {"workflow", "srpc", "pyworkflow"}; WFResourcePool res_pool((void * const*)words, res_concurrency); WFFacilities::WaitGroup wg(task_concurrency); for (int i = 0; i < task_concurrency; i++) { auto *user_task = WFTaskFactory::create_timer_task(0, [&wg, &res_pool](WFTimerTask *task) { uint64_t id = (uint64_t)series_of(task)->get_context(); printf("task-%lu get [%s]\n", id, (char *)task->user_data); res_pool.post(task->user_data); wg.done(); }); auto *cond = res_pool.get(user_task, &user_task->user_data); SeriesWork *series = Workflow::create_series_work(cond, nullptr); series->set_context(reinterpret_cast(i)); series->start(); } wg.wait(); } workflow-0.11.8/test/task_unittest.cc000066400000000000000000000164051476003635400176600ustar00rootroot00000000000000/* Copyright (c) 2020 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. Author: Wu Jiaxu (wujiaxu@sogou-inc.com) */ #include #include #include #include #include #include #include #include #include #include #include #include "workflow/WFTaskFactory.h" #define GET_CURRENT_MICRO std::chrono::duration_cast(std::chrono::steady_clock::now().time_since_epoch()).count() TEST(task_unittest, WFTimerTask) { std::mutex mutex; std::condition_variable cond; bool done = false; auto *task = WFTaskFactory::create_timer_task(1000000, [&mutex, &cond, &done](WFTimerTask *task) { EXPECT_EQ(task->get_state(), WFT_STATE_SUCCESS); mutex.lock(); done = true; mutex.unlock(); cond.notify_one(); }); int64_t st = GET_CURRENT_MICRO; task->start(); std::unique_lock lock(mutex); while (!done) cond.wait(lock); lock.unlock(); int64_t ed = GET_CURRENT_MICRO; EXPECT_LE(ed - st, 10000000) << "Timer Task too slow"; } TEST(task_unittest, WFCounterTask1) { std::mutex mutex; std::condition_variable cond; bool done = false; auto *task = WFTaskFactory::create_counter_task("abc", 2, [&mutex, &cond, &done](WFCounterTask *task) { auto state = task->get_state(); EXPECT_EQ(state, WFT_STATE_SUCCESS); if (state == WFT_STATE_SUCCESS) { WFTaskFactory::count_by_name("abc", 0); task->count(); WFTaskFactory::count_by_name("abc", 1); } mutex.lock(); done = true; mutex.unlock(); cond.notify_one(); }); task->start(); for (int i = 0; i < 100; i++) { WFTaskFactory::count_by_name("abc"); WFTaskFactory::count_by_name("abc"); } std::unique_lock lock(mutex); while (!done) cond.wait(lock); lock.unlock(); } TEST(task_unittest, WFCounterTask2) { std::mutex mutex; std::condition_variable cond; bool done = false; auto *task = WFTaskFactory::create_counter_task("def", 2, [&mutex, &cond, &done](WFCounterTask *task) { auto state = task->get_state(); EXPECT_EQ(state, WFT_STATE_SUCCESS); if (state == WFT_STATE_SUCCESS) { WFTaskFactory::count_by_name("def", 0); task->count(); WFTaskFactory::count_by_name("def", 
1); } mutex.lock(); done = true; mutex.unlock(); cond.notify_one(); }); task->count(); task->start(); task->count(); std::unique_lock lock(mutex); while (!done) cond.wait(lock); lock.unlock(); } TEST(task_unittest, WFGoTask) { srand(time(NULL)); std::mutex mutex; std::condition_variable cond; bool done = false; int target = rand() % 1024; int edit_inner = -1; auto&& f = [&mutex, &cond, &done, target, &edit_inner](int id) { EXPECT_EQ(target, id); edit_inner = 100; mutex.lock(); done = true; mutex.unlock(); cond.notify_one(); }; WFGoTask *task = WFTaskFactory::create_go_task("go", std::move(f), target); task->start(); std::unique_lock lock(mutex); while (!done) cond.wait(lock); lock.unlock(); EXPECT_EQ(edit_inner, 100); } TEST(task_unittest, WFThreadTask) { std::mutex mutex; std::condition_variable cond; bool done = false; using MyTaskIn = std::pair; using MyTaskOut = int; using MyFactory = WFThreadTaskFactory; using MyTask = WFThreadTask; auto&& calc_multi = [](MyTaskIn *in, MyTaskOut *out) { *out = in->first * in->second; }; auto *task = MyFactory::create_thread_task("calc", std::move(calc_multi), [&mutex, &cond, &done](MyTask *task) { auto state = task->get_state(); EXPECT_EQ(state, WFT_STATE_SUCCESS); if (state == WFT_STATE_SUCCESS) { auto *in = task->get_input(); auto *out = task->get_output(); EXPECT_EQ(in->first * in->second, *out); } mutex.lock(); done = true; mutex.unlock(); cond.notify_one(); }); task->start(); std::unique_lock lock(mutex); while (!done) cond.wait(lock); lock.unlock(); } TEST(task_unittest, WFFileIOTask) { srand(time(NULL)); std::mutex mutex; std::condition_variable cond; bool done = false; std::string file_path = "./" + std::to_string(time(NULL)) + "__" + std::to_string(rand() % 4096); int fd = open(file_path.c_str(), O_RDWR | O_CREAT, 0644); EXPECT_TRUE(fd > 0); char writebuf[] = "testtest"; char readbuf[16]; auto *write = WFTaskFactory::create_pwrite_task(fd, writebuf, 8, 80, [fd](WFFileIOTask *task) { auto state = task->get_state(); 
EXPECT_EQ(state, WFT_STATE_SUCCESS); if (state == WFT_STATE_SUCCESS) { auto *args = task->get_args(); EXPECT_EQ(args->fd, fd); EXPECT_EQ(args->count, 8); EXPECT_EQ(args->offset, 80); EXPECT_TRUE(strncmp("testtest", (char *)args->buf, 8) == 0); } }); auto *read = WFTaskFactory::create_pread_task(fd, readbuf, 8, 80, [fd](WFFileIOTask *task) { auto state = task->get_state(); EXPECT_EQ(state, WFT_STATE_SUCCESS); if (state == WFT_STATE_SUCCESS) { auto *args = task->get_args(); EXPECT_EQ(args->fd, fd); EXPECT_EQ(args->count, 8); EXPECT_EQ(args->offset, 80); EXPECT_TRUE(strncmp("testtest", (char *)args->buf, 8) == 0); } }); auto *series = Workflow::create_series_work(write, [&mutex, &cond, &done](const SeriesWork *series) { mutex.lock(); done = true; mutex.unlock(); cond.notify_one(); }); series->push_back(read); series->start(); std::unique_lock lock(mutex); while (!done) cond.wait(lock); lock.unlock(); close(fd); remove(file_path.c_str()); } TEST(task_unittest, WFFilePathIOTask) { srand(time(NULL)); std::mutex mutex; std::condition_variable cond; bool done = false; std::string file_path = "./" + std::to_string(time(NULL)) + "__" + std::to_string(rand() % 4096); char writebuf[] = "testtest"; char readbuf[16]; auto *write = WFTaskFactory::create_pwrite_task(file_path, writebuf, 8, 80, [](WFFileIOTask *task) { auto state = task->get_state(); EXPECT_EQ(state, WFT_STATE_SUCCESS); if (state == WFT_STATE_SUCCESS) { auto *args = task->get_args(); EXPECT_EQ(args->count, 8); EXPECT_EQ(args->offset, 80); EXPECT_TRUE(strncmp("testtest", (char *)args->buf, 8) == 0); } }); auto *read = WFTaskFactory::create_pread_task(file_path, readbuf, 8, 80, [](WFFileIOTask *task) { auto state = task->get_state(); EXPECT_EQ(state, WFT_STATE_SUCCESS); if (state == WFT_STATE_SUCCESS) { auto *args = task->get_args(); EXPECT_EQ(args->count, 8); EXPECT_EQ(args->offset, 80); EXPECT_TRUE(strncmp("testtest", (char *)args->buf, 8) == 0); } }); auto *series = Workflow::create_series_work(write, [&mutex, 
&cond, &done](const SeriesWork *series) { mutex.lock(); done = true; mutex.unlock(); cond.notify_one(); }); series->push_back(read); series->start(); std::unique_lock lock(mutex); while (!done) cond.wait(lock); lock.unlock(); remove(file_path.c_str()); } workflow-0.11.8/test/upstream_unittest.cc000066400000000000000000000331101476003635400205460ustar00rootroot00000000000000/* Copyright (c) 2021 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Li Yingxin (liyingxin@sogou-inc.com) */ #include #include "workflow/UpstreamManager.h" #include "workflow/WFHttpServer.h" #include "workflow/WFTaskFactory.h" #include "workflow/WFFacilities.h" #include "workflow/UpstreamPolicies.h" #define REDIRECT_MAX 2 #define RETRY_MAX 2 #define MTTR 2 #define MAX_FAILS 200 static void __http_process(WFHttpTask *task, const char *name) { auto *resp = task->get_resp(); resp->add_header_pair("Content-Type", "text/plain"); resp->append_output_body_nocopy(name, strlen(name)); } WFHttpServer http_server1(std::bind(&__http_process, std::placeholders::_1, "server1")); WFHttpServer http_server2(std::bind(&__http_process, std::placeholders::_1, "server2")); WFHttpServer http_server3(std::bind(&__http_process, std::placeholders::_1, "server3")); void register_upstream_hosts() { UpstreamManager::upstream_create_weighted_random("weighted.random", false); AddressParams address_params = ADDRESS_PARAMS_DEFAULT; address_params.weight = 1000; UpstreamManager::upstream_add_server("weighted.random", "127.0.0.1:8001", 
&address_params); address_params.weight = 1; UpstreamManager::upstream_add_server("weighted.random", "127.0.0.1:8002", &address_params); UpstreamManager::upstream_create_consistent_hash( "hash", [](const char *path, const char *query, const char *fragment) -> unsigned int { return 4250947057; // test skip from the end to the begin, hit 8002 }); UpstreamManager::upstream_add_server("hash", "127.0.0.1:8001"); UpstreamManager::upstream_add_server("hash", "127.0.0.1:8001"); UpstreamManager::upstream_add_server("hash", "127.0.0.1:8002"); UpstreamManager::upstream_create_manual( "manual", [](const char *path, const char *query, const char *fragment) -> unsigned int { return 0; }, true, [](const char *path, const char *query, const char *fragment) -> unsigned int { return 511702306; // test skip the non-alive server }); UpstreamManager::upstream_add_server("manual", "127.0.0.1:8001"); UpstreamManager::upstream_add_server("manual", "127.0.0.1:8002"); UpstreamManager::upstream_create_round_robin("round.robin", true); UpstreamManager::upstream_add_server("round.robin", "127.0.0.1:8001"); UpstreamManager::upstream_add_server("round.robin", "127.0.0.1:8002"); UpstreamManager::upstream_create_manual( "try_another", [](const char *path, const char *query, const char *fragment) -> unsigned int { return 0; }, false, nullptr); UpstreamManager::upstream_add_server("try_another", "127.0.0.1:8001"); UpstreamManager::upstream_add_server("try_another", "127.0.0.1:8002"); UpstreamManager::upstream_create_weighted_random("test_tracing", true); address_params.weight = 1000; UpstreamManager::upstream_add_server("test_tracing", "127.0.0.1:8001", &address_params); address_params.weight = 1; UpstreamManager::upstream_add_server("test_tracing", "127.0.0.1:8002", &address_params); address_params.weight = 1000; UpstreamManager::upstream_add_server("test_tracing", "127.0.0.1:8003", &address_params); } void basic_callback(WFHttpTask *task, std::string& message) { int state = task->get_state(); 
EXPECT_EQ(state, WFT_STATE_SUCCESS); if (state == WFT_STATE_SUCCESS && message.compare("")) { const void *body; size_t body_len; task->get_resp()->get_parsed_body(&body, &body_len); std::string buffer((char *)body, body_len); EXPECT_EQ(buffer, message); } WFFacilities::WaitGroup *wait_group = (WFFacilities::WaitGroup *)task->user_data; wait_group->done(); } TEST(upstream_unittest, BasicPolicy) { WFFacilities::WaitGroup wait_group(5); WFHttpTask *task1; WFHttpTask *task2; char url[4][30] = {"http://weighted.random", "http://manual", "http://hash", "http://round.robin"}; http_callback_t cb1 = std::bind(basic_callback, std::placeholders::_1, std::string("server1")); for (int i = 0; i < 2; i++) { task1 = WFTaskFactory::create_http_task(url[i], REDIRECT_MAX, RETRY_MAX, cb1); task1->user_data = &wait_group; task1->start(); } http_callback_t cb2 = std::bind(basic_callback, std::placeholders::_1, std::string("server2")); task2 = WFTaskFactory::create_http_task(url[2], REDIRECT_MAX, RETRY_MAX, cb2); task2->user_data = &wait_group; task2->start(); task1 = WFTaskFactory::create_http_task(url[3], REDIRECT_MAX, RETRY_MAX, cb1); task1->user_data = &wait_group; task2 = WFTaskFactory::create_http_task(url[3], REDIRECT_MAX, RETRY_MAX, cb2); task2->user_data = &wait_group; SeriesWork *series = Workflow::create_series_work(task1, nullptr); series->push_back(task2); series->start(); wait_group.wait(); } TEST(upstream_unittest, EnableAndDisable) { WFFacilities::WaitGroup wait_group(1); UpstreamManager::upstream_disable_server("weighted.random", "127.0.0.1:8001"); std::string url = "http://weighted.random"; WFHttpTask *task = WFTaskFactory::create_http_task(url, REDIRECT_MAX, RETRY_MAX, [&wait_group, &url](WFHttpTask *task){ int state = task->get_state(); EXPECT_EQ(state, WFT_STATE_TASK_ERROR); EXPECT_EQ(task->get_error(), WFT_ERR_UPSTREAM_UNAVAILABLE); UpstreamManager::upstream_enable_server("weighted.random", "127.0.0.1:8001"); auto *task2 = WFTaskFactory::create_http_task(url, 
REDIRECT_MAX, RETRY_MAX, std::bind(basic_callback, std::placeholders::_1, std::string("server1"))); task2->user_data = &wait_group; series_of(task)->push_back(task2); }); task->user_data = &wait_group; task->start(); wait_group.wait(); } TEST(upstream_unittest, AddAndRemove) { WFFacilities::WaitGroup wait_group(2); WFHttpTask *task; SeriesWork *series; protocol::HttpRequest *req; int batch = MAX_FAILS + 50; std::string url = "http://add_and_remove"; std::string name = "add_and_remove"; UPSWeightedRandomPolicy test_policy(false); AddressParams address_params = ADDRESS_PARAMS_DEFAULT; address_params.weight = 1000; test_policy.add_server("127.0.0.1:8001", &address_params); address_params.weight = 1; test_policy.add_server("127.0.0.1:8002", &address_params); auto *ns = WFGlobal::get_name_service(); EXPECT_EQ(ns->add_policy(name.c_str(), &test_policy), 0); UpstreamManager::upstream_remove_server(name, "127.0.0.1:8001"); task = WFTaskFactory::create_http_task(url, REDIRECT_MAX, RETRY_MAX, std::bind(basic_callback, std::placeholders::_1, std::string("server2"))); task->user_data = &wait_group; task->start(); //test remove fused server address_params.weight = 1000; test_policy.add_server("127.0.0.1:8001", &address_params); http_server1.stop(); fprintf(stderr, "server 1 stopped start %d tasks to fuse it\n", batch); ParallelWork *pwork = Workflow::create_parallel_work( [&wait_group, &name, &url](const ParallelWork *pwork) { fprintf(stderr, "parallel finished and remove server1\n"); UpstreamManager::upstream_remove_server(name, "127.0.0.1:8001"); auto *task = WFTaskFactory::create_http_task(url, REDIRECT_MAX, RETRY_MAX, std::bind(basic_callback, std::placeholders::_1, std::string("server2"))); task->user_data = &wait_group; series_of(pwork)->push_back(task); }); for (int i = 0; i < batch; i++) { task = WFTaskFactory::create_http_task(url, 0, 0, nullptr); req = task->get_req(); req->add_header_pair("Connection", "keep-alive"); series = Workflow::create_series_work(task, 
nullptr); pwork->add_series(series); } pwork->start(); wait_group.wait(); EXPECT_TRUE(http_server1.start("127.0.0.1", 8001) == 0) << "http server start failed"; ns->del_policy(name.c_str()); } TEST(upstream_unittest, FuseAndRecover) { WFFacilities::WaitGroup wait_group(1); WFHttpTask *task; SeriesWork *series; protocol::HttpRequest *req; std::string url = "http://test_policy"; int batch = MAX_FAILS + 50; int timeout = (MTTR + 1) * 1000000; UPSWeightedRandomPolicy test_policy(false); test_policy.set_mttr_seconds(MTTR); AddressParams address_params = ADDRESS_PARAMS_DEFAULT; address_params.weight = 1000; test_policy.add_server("127.0.0.1:8001", &address_params); address_params.weight = 1; test_policy.add_server("127.0.0.1:8002", &address_params); auto *ns = WFGlobal::get_name_service(); EXPECT_EQ(ns->add_policy("test_policy", &test_policy), 0); http_server1.stop(); fprintf(stderr, "server 1 stopped start %d tasks to fuse it\n", batch); ParallelWork *pwork = Workflow::create_parallel_work( [](const ParallelWork *pwork) { fprintf(stderr, "parallel finished\n"); }); for (int i = 0; i < batch; i++) { task = WFTaskFactory::create_http_task(url, 0, 0, nullptr); req = task->get_req(); req->add_header_pair("Connection", "keep-alive"); series = Workflow::create_series_work(task, nullptr); pwork->add_series(series); } series = Workflow::create_series_work(pwork, nullptr); WFTimerTask *timer = WFTaskFactory::create_timer_task(timeout, [](WFTimerTask *task) { fprintf(stderr, "timer_finished and start server1\n"); EXPECT_TRUE(http_server1.start("127.0.0.1", 8001) == 0) << "http server start failed"; }); series->push_back(timer); task = WFTaskFactory::create_http_task(url, REDIRECT_MAX, RETRY_MAX, std::bind(basic_callback, std::placeholders::_1, std::string("server1"))); task->user_data = &wait_group; series->push_back(task); series->start(); wait_group.wait(); ns->del_policy("test_policy"); } TEST(upstream_unittest, TryAnother) { WFFacilities::WaitGroup wait_group(3); 
UpstreamManager::upstream_disable_server("manual", "127.0.0.1:8001"); UpstreamManager::upstream_disable_server("round.robin", "127.0.0.1:8001"); UpstreamManager::upstream_disable_server("try_another", "127.0.0.1:8001"); http_callback_t cb2 = std::bind(basic_callback, std::placeholders::_1, std::string("server2")); WFHttpTask *task = WFTaskFactory::create_http_task("http://manual", REDIRECT_MAX, RETRY_MAX, cb2); task->user_data = &wait_group; task->start(); // this->cur_idx == 1. Will skip 8001 and try 8002. task = WFTaskFactory::create_http_task("http://round.robin", REDIRECT_MAX, RETRY_MAX, cb2); task->user_data = &wait_group; task->start(); task = WFTaskFactory::create_http_task("http://try_another", REDIRECT_MAX, RETRY_MAX, [&wait_group](WFHttpTask *task){ int state = task->get_state(); EXPECT_EQ(state, WFT_STATE_TASK_ERROR); EXPECT_EQ(task->get_error(), WFT_ERR_UPSTREAM_UNAVAILABLE); wait_group.done(); }); task->start(); wait_group.wait(); UpstreamManager::upstream_enable_server("manual", "127.0.0.1:8001"); UpstreamManager::upstream_enable_server("round.robin", "127.0.0.1:8001"); UpstreamManager::upstream_enable_server("try_another", "127.0.0.1:8001"); } TEST(upstream_unittest, Tracing) { WFFacilities::WaitGroup wait_group(2); http_server1.stop(); // test first_strategy() WFHttpTask *task = WFTaskFactory::create_http_task( "http://weighted.random", REDIRECT_MAX, RETRY_MAX, std::bind(basic_callback, std::placeholders::_1, std::string("server2"))); task->user_data = &wait_group; task->start(); // test another_strategy() UpstreamManager::upstream_disable_server("test_tracing", "127.0.0.1:8003"); WFHttpTask *task2 = WFTaskFactory::create_http_task( "http://test_tracing", REDIRECT_MAX, RETRY_MAX, std::bind(basic_callback, std::placeholders::_1, std::string("server2"))); task2->user_data = &wait_group; task2->start(); wait_group.wait(); EXPECT_TRUE(http_server1.start("127.0.0.1", 8001) == 0) << "http server start failed"; 
UpstreamManager::upstream_enable_server("test_tracing", "127.0.0.1:8003"); } TEST(upstream_unittest, RoundRobin) { WFFacilities::WaitGroup wait_group(1); // this->cur_idx = 0. When 8002 is removed, we will try 8001. UpstreamManager::upstream_remove_server("round.robin", "127.0.0.1:8002"); WFHttpTask *task = WFTaskFactory::create_http_task("http://round.robin", REDIRECT_MAX, RETRY_MAX, std::bind(basic_callback, std::placeholders::_1, std::string("server1"))); task->user_data = &wait_group; task->start(); wait_group.wait(); UpstreamManager::upstream_add_server("round.robin", "127.0.0.1:8002"); } int main(int argc, char* argv[]) { ::testing::InitGoogleTest(&argc, argv); register_upstream_hosts(); EXPECT_TRUE(http_server1.start("127.0.0.1", 8001) == 0) << "http server start failed"; EXPECT_TRUE(http_server2.start("127.0.0.1", 8002) == 0) << "http server start failed"; EXPECT_TRUE(http_server3.start("127.0.0.1", 8003) == 0) << "http server start failed"; EXPECT_EQ(RUN_ALL_TESTS(), 0); EXPECT_EQ(UpstreamManager::upstream_delete("try_another"), 0); EXPECT_EQ(UpstreamManager::upstream_delete("try_another"), -1); http_server1.stop(); http_server2.stop(); http_server3.stop(); return 0; } workflow-0.11.8/test/uriparser_unittest.cc000066400000000000000000000214641476003635400207330ustar00rootroot00000000000000/* Copyright (c) 2021 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Author: Wang Zhulei (wangzhulei@sogou-inc.com) */ #include #include "workflow/URIParser.h" TEST(uriparser_unittest, parse) { ParsedURI uri; EXPECT_EQ(URIParser::parse("https://john.doe:pass@www.example.com:123/forum/questions/?tag=networking&order=newest#top", uri), 0); EXPECT_EQ(strcmp(uri.scheme, "https"), 0); EXPECT_EQ(strcmp(uri.userinfo, "john.doe:pass"), 0); EXPECT_EQ(strcmp(uri.host, "www.example.com"), 0); EXPECT_EQ(strcmp(uri.port, "123"), 0); EXPECT_EQ(strcmp(uri.path, "/forum/questions/"), 0); EXPECT_EQ(strcmp(uri.query, "tag=networking&order=newest"), 0); EXPECT_EQ(strcmp(uri.fragment, "top"), 0); EXPECT_EQ(URIParser::parse("https://john.doe@www.example.com:123/forum/questions/?tag=networking&order=newest#top", uri), 0); EXPECT_EQ(strcmp(uri.scheme, "https"), 0); EXPECT_EQ(strcmp(uri.userinfo, "john.doe"), 0); EXPECT_EQ(strcmp(uri.host, "www.example.com"), 0); EXPECT_EQ(strcmp(uri.port, "123"), 0); EXPECT_EQ(strcmp(uri.path, "/forum/questions/"), 0); EXPECT_EQ(strcmp(uri.query, "tag=networking&order=newest"), 0); EXPECT_EQ(strcmp(uri.fragment, "top"), 0); EXPECT_EQ(URIParser::parse("ldap://[2001:db8::7]/c=GB?objectClass?one", uri), 0); EXPECT_EQ(strcmp(uri.scheme, "ldap"), 0); EXPECT_EQ(uri.userinfo, nullptr); EXPECT_EQ(strcmp(uri.host, "2001:db8::7"), 0); EXPECT_EQ(uri.port, nullptr); EXPECT_EQ(strcmp(uri.path, "/c=GB"), 0); EXPECT_EQ(strcmp(uri.query, "objectClass?one"), 0); EXPECT_EQ(uri.fragment, nullptr); EXPECT_EQ(URIParser::parse("ldap://user@[2001:db8::7]/c=GB?objectClass?one", uri), 0); EXPECT_EQ(strcmp(uri.scheme, "ldap"), 0); EXPECT_EQ(strcmp(uri.userinfo, "user"), 0); EXPECT_EQ(strcmp(uri.host, "2001:db8::7"), 0); EXPECT_EQ(uri.port, nullptr); EXPECT_EQ(strcmp(uri.path, "/c=GB"), 0); EXPECT_EQ(strcmp(uri.query, "objectClass?one"), 0); EXPECT_EQ(uri.fragment, nullptr); EXPECT_EQ(URIParser::parse("ldap://user@[2001:db8::7]:12345/c=GB?objectClass?one", uri), 0); EXPECT_EQ(strcmp(uri.scheme, "ldap"), 0); EXPECT_EQ(strcmp(uri.userinfo, "user"), 
0); EXPECT_EQ(strcmp(uri.host, "2001:db8::7"), 0); EXPECT_EQ(strcmp(uri.port, "12345"), 0); EXPECT_EQ(strcmp(uri.path, "/c=GB"), 0); EXPECT_EQ(strcmp(uri.query, "objectClass?one"), 0); EXPECT_EQ(uri.fragment, nullptr); EXPECT_EQ(URIParser::parse("ldap://[2001:db8::7]:12345/c=GB?objectClass?one", uri), 0); EXPECT_EQ(strcmp(uri.scheme, "ldap"), 0); EXPECT_EQ(uri.userinfo, nullptr); EXPECT_EQ(strcmp(uri.host, "2001:db8::7"), 0); EXPECT_EQ(strcmp(uri.port, "12345"), 0); EXPECT_EQ(strcmp(uri.path, "/c=GB"), 0); EXPECT_EQ(strcmp(uri.query, "objectClass?one"), 0); EXPECT_EQ(uri.fragment, nullptr); EXPECT_EQ(URIParser::parse("mailto:John.Doe@example.com", uri), 0); EXPECT_EQ(strcmp(uri.scheme, "mailto"), 0); EXPECT_EQ(uri.userinfo, nullptr); EXPECT_EQ(uri.host, nullptr); EXPECT_EQ(uri.port, nullptr); EXPECT_EQ(strcmp(uri.path, "John.Doe@example.com"), 0); EXPECT_EQ(uri.query, nullptr); EXPECT_EQ(uri.fragment, nullptr); EXPECT_EQ(URIParser::parse("news:comp.infosystems.www.servers.unix", uri), 0); EXPECT_EQ(strcmp(uri.scheme, "news"), 0); EXPECT_EQ(uri.userinfo, nullptr); EXPECT_EQ(uri.host, nullptr); EXPECT_EQ(uri.port, nullptr); EXPECT_EQ(strcmp(uri.path, "comp.infosystems.www.servers.unix"), 0); EXPECT_EQ(uri.query, nullptr); EXPECT_EQ(uri.fragment, nullptr); EXPECT_EQ(URIParser::parse("tel:+1-816-555-1212", uri), 0); EXPECT_EQ(strcmp(uri.scheme, "tel"), 0); EXPECT_EQ(uri.userinfo, nullptr); EXPECT_EQ(uri.host, nullptr); EXPECT_EQ(uri.port, nullptr); EXPECT_EQ(strcmp(uri.path, "+1-816-555-1212"), 0); EXPECT_EQ(uri.query, nullptr); EXPECT_EQ(uri.fragment, nullptr); EXPECT_EQ(URIParser::parse("telnet://192.0.2.16:80/", uri), 0); EXPECT_EQ(strcmp(uri.scheme, "telnet"), 0); EXPECT_EQ(uri.userinfo, nullptr); EXPECT_EQ(strcmp(uri.host, "192.0.2.16"), 0); EXPECT_EQ(strcmp(uri.port, "80"), 0); EXPECT_EQ(strcmp(uri.path, "/"), 0); EXPECT_EQ(uri.query, nullptr); EXPECT_EQ(uri.fragment, nullptr); EXPECT_EQ(URIParser::parse("urn:oasis:names:specification:docbook:dtd:xml:4.1.2", 
uri), 0); EXPECT_EQ(strcmp(uri.scheme, "urn"), 0); EXPECT_EQ(uri.userinfo, nullptr); EXPECT_EQ(uri.host, nullptr); EXPECT_EQ(uri.port, nullptr); EXPECT_EQ(strcmp(uri.path, "oasis:names:specification:docbook:dtd:xml:4.1.2"), 0); EXPECT_EQ(uri.query, nullptr); EXPECT_EQ(uri.fragment, nullptr); EXPECT_EQ(URIParser::parse("https://www.example.com:123/forum/questions/?tag=networking&order=newest#top", uri), 0); EXPECT_EQ(strcmp(uri.scheme, "https"), 0); EXPECT_EQ(uri.userinfo, nullptr); EXPECT_EQ(strcmp(uri.host, "www.example.com"), 0); EXPECT_EQ(strcmp(uri.port, "123"), 0); EXPECT_EQ(strcmp(uri.path, "/forum/questions/"), 0); EXPECT_EQ(strcmp(uri.query, "tag=networking&order=newest"), 0); EXPECT_EQ(strcmp(uri.fragment, "top"), 0); EXPECT_EQ(URIParser::parse("https://john.doe@www.example.com/forum/questions/?tag=networking&order=newest#top", uri), 0); EXPECT_EQ(strcmp(uri.scheme, "https"), 0); EXPECT_EQ(strcmp(uri.userinfo, "john.doe"), 0); EXPECT_EQ(strcmp(uri.host, "www.example.com"), 0); EXPECT_EQ(uri.port, nullptr); EXPECT_EQ(strcmp(uri.path, "/forum/questions/"), 0); EXPECT_EQ(strcmp(uri.query, "tag=networking&order=newest"), 0); EXPECT_EQ(strcmp(uri.fragment, "top"), 0); EXPECT_EQ(URIParser::parse("foo:/index.html", uri), 0); EXPECT_EQ(strcmp(uri.scheme, "foo"), 0); EXPECT_EQ(uri.userinfo, nullptr); EXPECT_EQ(uri.host, nullptr); EXPECT_EQ(uri.port, nullptr); EXPECT_EQ(strcmp(uri.path, "/index.html"), 0); EXPECT_EQ(uri.query, nullptr); EXPECT_EQ(uri.fragment, nullptr); EXPECT_EQ(URIParser::parse("http://www.test.cn/subject/ttt/index.html?abc-def-jki-lm-rstuvwxyz", uri), 0); EXPECT_EQ(strcmp(uri.scheme, "http"), 0); EXPECT_EQ(uri.userinfo, nullptr); EXPECT_EQ(strcmp(uri.host, "www.test.cn"), 0); EXPECT_EQ(uri.port, nullptr); EXPECT_EQ(strcmp(uri.path, "/subject/ttt/index.html"), 0); EXPECT_EQ(strcmp(uri.query, "abc-def-jki-lm-rstuvwxyz"), 0); EXPECT_EQ(uri.fragment, nullptr); EXPECT_EQ(URIParser::parse("http://sg.test1.com/zt/zz/#IJ", uri), 0); 
EXPECT_EQ(strcmp(uri.scheme, "http"), 0); EXPECT_EQ(uri.userinfo, nullptr); EXPECT_EQ(strcmp(uri.host, "sg.test1.com"), 0); EXPECT_EQ(uri.port, nullptr); EXPECT_EQ(strcmp(uri.path, "/zt/zz/"), 0); EXPECT_EQ(uri.query, nullptr); EXPECT_EQ(strcmp(uri.fragment, "IJ"), 0); EXPECT_EQ(URIParser::parse("http://www.test2.com?sg_vid=R_3qHh9H471Ry8OtW5J9R10vc_QR6EQqgA6HHLO6666666qe0Co66666666", uri), 0); EXPECT_EQ(strcmp(uri.scheme, "http"), 0); EXPECT_EQ(uri.userinfo, nullptr); EXPECT_EQ(strcmp(uri.host, "www.test2.com"), 0); EXPECT_EQ(uri.port, nullptr); EXPECT_EQ(uri.path, nullptr); EXPECT_EQ(strcmp(uri.query, "sg_vid=R_3qHh9H471Ry8OtW5J9R10vc_QR6EQqgA6HHLO6666666qe0Co66666666"), 0); EXPECT_EQ(uri.fragment, nullptr); EXPECT_EQ(URIParser::parse("https://sgsares.test3.com/ttts/IJ_4115.apk", uri), 0); EXPECT_EQ(strcmp(uri.scheme, "https"), 0); EXPECT_EQ(uri.userinfo, nullptr); EXPECT_EQ(strcmp(uri.host, "sgsares.test3.com"), 0); EXPECT_EQ(uri.port, nullptr); EXPECT_EQ(strcmp(uri.path, "/ttts/IJ_4115.apk"), 0); EXPECT_EQ(uri.query, nullptr); EXPECT_EQ(uri.fragment, nullptr); EXPECT_EQ(URIParser::parse("http://viptest.test5.com:8484?sg_vid=Rucnk5BKG81RcIVk7XySNhQtBODR6mKXA06PpWA66666663MTAfR6666666", uri), 0); EXPECT_EQ(strcmp(uri.scheme, "http"), 0); EXPECT_EQ(uri.userinfo, nullptr); EXPECT_EQ(strcmp(uri.host, "viptest.test5.com"), 0); EXPECT_EQ(strcmp(uri.port, "8484"), 0); EXPECT_EQ(uri.path, nullptr); EXPECT_EQ(strcmp(uri.query, "sg_vid=Rucnk5BKG81RcIVk7XySNhQtBODR6mKXA06PpWA66666663MTAfR6666666"), 0); EXPECT_EQ(uri.fragment, nullptr); EXPECT_EQ(URIParser::parse("http://viptest1.test6.com:84/abc#frag", uri), 0); EXPECT_EQ(strcmp(uri.scheme, "http"), 0); EXPECT_EQ(uri.userinfo, nullptr); EXPECT_EQ(strcmp(uri.host, "viptest1.test6.com"), 0); EXPECT_EQ(strcmp(uri.port, "84"), 0); EXPECT_EQ(strcmp(uri.path, "/abc"), 0); EXPECT_EQ(uri.query, nullptr); EXPECT_EQ(strcmp(uri.fragment, "frag"), 0); EXPECT_EQ(URIParser::parse("http://www.sogou.com", uri), 0); EXPECT_EQ(uri.path, 
nullptr); EXPECT_EQ(uri.query, nullptr); EXPECT_EQ(uri.fragment, nullptr); EXPECT_EQ(URIParser::parse("http://www.sogou.com?aaa/=bbb", uri), 0); EXPECT_EQ(uri.path, nullptr); EXPECT_EQ(strcmp(uri.query, "aaa/=bbb"), 0); } workflow-0.11.8/test/xmake.lua000066400000000000000000000021541476003635400162540ustar00rootroot00000000000000set_group("test") set_default(false) add_requires("gtest") add_deps("workflow") add_packages("gtest") add_links("gtest_main") if not is_plat("macosx") then add_ldflags("-lrt") end function all_tests() local res = {} for _, x in ipairs(os.files("**.cc")) do local item = {} local s = path.filename(x) if ((s == "upstream_unittest.cc" and not has_config("upstream")) or (s == "redis_unittest.cc" and not has_config("redis")) or (s == "mysql_unittest.cc" and not has_config("mysql"))) then else table.insert(item, s:sub(1, #s - 3)) -- target table.insert(item, path.relative(x, ".")) -- source table.insert(res, item) end end return res end for _, test in ipairs(all_tests()) do target(test[1]) set_kind("binary") add_files(test[2]) if has_config("memcheck") then on_run(function (target) local argv = {} table.insert(argv, target:targetfile()) table.insert(argv, "--leak-check=full") os.execv("valgrind", argv) end) end end workflow-0.11.8/tutorial/000077500000000000000000000000001476003635400153265ustar00rootroot00000000000000workflow-0.11.8/tutorial/CMakeLists.txt000066400000000000000000000100601476003635400200630ustar00rootroot00000000000000cmake_minimum_required(VERSION 3.6) set(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING "build type") project(tutorial LANGUAGES C CXX ) set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${PROJECT_SOURCE_DIR}) if(ANDROID) link_directories(${OPENSSL_LINK_DIR}) else() find_library(LIBRT rt) find_package(OpenSSL REQUIRED) endif() find_package(workflow REQUIRED CONFIG HINTS ..) 
include_directories(${OPENSSL_INCLUDE_DIR} ${WORKFLOW_INCLUDE_DIR}) link_directories(${WORKFLOW_LIB_DIR}) if (KAFKA STREQUAL "y") find_path(SNAPPY_INCLUDE_PATH NAMES snappy.h) include_directories(${SNAPPY_INCLUDE_PATH}) endif () if (WIN32) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /MP /wd4200") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP /wd4200 /std:c++14") else () set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wall -fPIC -pipe -std=gnu90") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -fPIC -pipe -std=c++11 -fno-exceptions") if (APPLE) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-deprecated-declarations") endif() endif () set(TUTORIAL_LIST tutorial-00-helloworld tutorial-01-wget tutorial-04-http_echo_server tutorial-05-http_proxy tutorial-06-parallel_wget tutorial-07-sort_task tutorial-08-matrix_multiply tutorial-09-http_file_server tutorial-11-graph_task tutorial-15-name_service tutorial-17-dns_cli tutorial-19-dns_server ) if (APPLE) set(WORKFLOW_LIB workflow pthread OpenSSL::SSL OpenSSL::Crypto) elseif (ANDROID) set(WORKFLOW_LIB workflow ssl crypto c) else () set(WORKFLOW_LIB workflow pthread OpenSSL::SSL OpenSSL::Crypto ${LIBRT}) endif () foreach(src ${TUTORIAL_LIST}) string(REPLACE "-" ";" arr ${src}) list(GET arr -1 bin_name) add_executable(${bin_name} ${src}.cc) target_link_libraries(${bin_name} ${WORKFLOW_LIB}) endforeach() if (NOT REDIS STREQUAL "n") set(TUTORIAL_LIST tutorial-02-redis_cli tutorial-03-wget_to_redis tutorial-18-redis_subscriber ) foreach(src ${TUTORIAL_LIST}) string(REPLACE "-" ";" arr ${src}) list(GET arr -1 bin_name) add_executable(${bin_name} ${src}.cc) target_link_libraries(${bin_name} ${WORKFLOW_LIB}) endforeach() endif() if (NOT MYSQL STREQUAL "n") set(TUTORIAL_LIST tutorial-12-mysql_cli ) foreach(src ${TUTORIAL_LIST}) string(REPLACE "-" ";" arr ${src}) list(GET arr -1 bin_name) add_executable(${bin_name} ${src}.cc) target_link_libraries(${bin_name} ${WORKFLOW_LIB}) endforeach() endif() if (NOT CONSUL STREQUAL "n") set(TUTORIAL_LIST 
tutorial-14-consul_cli ) foreach(src ${TUTORIAL_LIST}) string(REPLACE "-" ";" arr ${src}) list(GET arr -1 bin_name) add_executable(${bin_name} ${src}.cc) target_link_libraries(${bin_name} ${WORKFLOW_LIB}) endforeach() endif() if (KAFKA STREQUAL "y") add_executable("kafka_cli" "tutorial-13-kafka_cli.cc") target_link_libraries("kafka_cli" wfkafka ${WORKFLOW_LIB} z snappy lz4 zstd) endif () set(DIR10 tutorial-10-user_defined_protocol) add_executable(server ${DIR10}/server.cc ${DIR10}/message.cc) add_executable(client ${DIR10}/client.cc ${DIR10}/message.cc) add_executable(server-uds ${DIR10}/server-uds.cc ${DIR10}/message.cc) add_executable(client-uds ${DIR10}/client-uds.cc ${DIR10}/message.cc) target_link_libraries(server ${WORKFLOW_LIB}) target_link_libraries(client ${WORKFLOW_LIB}) target_link_libraries(server-uds ${WORKFLOW_LIB}) target_link_libraries(client-uds ${WORKFLOW_LIB}) set_target_properties(server PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${PROJECT_SOURCE_DIR}/${DIR10}) set_target_properties(client PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${PROJECT_SOURCE_DIR}/${DIR10}) set_target_properties(server-uds PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${PROJECT_SOURCE_DIR}/${DIR10}) set_target_properties(client-uds PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${PROJECT_SOURCE_DIR}/${DIR10}) set(DIR16 tutorial-16-graceful_restart) add_executable(bootstrap ${DIR16}/bootstrap.c) add_executable(bootstrap_server ${DIR16}/server.cc) target_link_libraries(bootstrap_server ${WORKFLOW_LIB}) set_target_properties(bootstrap PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${PROJECT_SOURCE_DIR}/${DIR16}) set_target_properties(bootstrap_server PROPERTIES OUTPUT_NAME "server" RUNTIME_OUTPUT_DIRECTORY ${PROJECT_SOURCE_DIR}/${DIR16}) workflow-0.11.8/tutorial/GNUmakefile000066400000000000000000000017131476003635400174020ustar00rootroot00000000000000ROOT_DIR := $(shell dirname $(realpath $(firstword $(MAKEFILE_LIST)))) ALL_TARGETS := all clean MAKE_FILE := Makefile DEFAULT_BUILD_DIR := build.cmake BUILD_DIR := $(shell 
if [ -f $(MAKE_FILE) ]; then echo "."; else echo $(DEFAULT_BUILD_DIR); fi) CMAKE3 := $(shell if which cmake3>/dev/null ; then echo cmake3; else echo cmake; fi;) .PHONY: $(ALL_TARGETS) all: mkdir -p $(BUILD_DIR) rm -rf $(DEFAULT_BUILD_DIR)/CMakeCache.txt ifeq ($(DEBUG),y) cd $(BUILD_DIR) && $(CMAKE3) -D CMAKE_BUILD_TYPE=Debug -D CONSUL=$(CONSUL) -D KAFKA=$(KAFKA) -D MYSQL=$(MYSQL) -D REDIS=$(REDIS) $(ROOT_DIR) else cd $(BUILD_DIR) && $(CMAKE3) -D CONSUL=$(CONSUL) -D KAFKA=$(KAFKA) -D MYSQL=$(MYSQL) -D REDIS=$(REDIS) $(ROOT_DIR) endif make -C $(BUILD_DIR) -f Makefile clean: ifeq ($(MAKE_FILE), $(wildcard $(MAKE_FILE))) -make -f Makefile clean else ifeq ($(DEFAULT_BUILD_DIR), $(wildcard $(DEFAULT_BUILD_DIR))) -make -C $(DEFAULT_BUILD_DIR) clean endif rm -rf $(DEFAULT_BUILD_DIR) workflow-0.11.8/tutorial/tutorial-00-helloworld.cc000066400000000000000000000005421476003635400220670ustar00rootroot00000000000000#include #include "workflow/WFHttpServer.h" int main() { WFHttpServer server([](WFHttpTask *task) { task->get_resp()->append_output_body("Hello World!"); }); if (server.start(8888) == 0) { // start server on port 8888 getchar(); // press "Enter" to end. server.stop(); } return 0; } workflow-0.11.8/tutorial/tutorial-01-wget.cc000066400000000000000000000064201476003635400206640ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Author: Xie Han (xiehan@sogou-inc.com;63350856@qq.com) */ #include #include #include #include #include #include #include "workflow/HttpMessage.h" #include "workflow/HttpUtil.h" #include "workflow/WFTaskFactory.h" #include "workflow/WFFacilities.h" #define REDIRECT_MAX 5 #define RETRY_MAX 2 void wget_callback(WFHttpTask *task) { protocol::HttpRequest *req = task->get_req(); protocol::HttpResponse *resp = task->get_resp(); int state = task->get_state(); int error = task->get_error(); switch (state) { case WFT_STATE_SYS_ERROR: fprintf(stderr, "system error: %s\n", strerror(error)); break; case WFT_STATE_DNS_ERROR: fprintf(stderr, "DNS error: %s\n", gai_strerror(error)); break; case WFT_STATE_SSL_ERROR: fprintf(stderr, "SSL error: %d\n", error); break; case WFT_STATE_TASK_ERROR: fprintf(stderr, "Task error: %d\n", error); break; case WFT_STATE_SUCCESS: break; } if (state != WFT_STATE_SUCCESS) { fprintf(stderr, "Failed. Press Ctrl-C to exit.\n"); return; } std::string name; std::string value; /* Print request. */ fprintf(stderr, "%s %s %s\r\n", req->get_method(), req->get_http_version(), req->get_request_uri()); protocol::HttpHeaderCursor req_cursor(req); while (req_cursor.next(name, value)) fprintf(stderr, "%s: %s\r\n", name.c_str(), value.c_str()); fprintf(stderr, "\r\n"); /* Print response header. */ fprintf(stderr, "%s %s %s\r\n", resp->get_http_version(), resp->get_status_code(), resp->get_reason_phrase()); protocol::HttpHeaderCursor resp_cursor(resp); while (resp_cursor.next(name, value)) fprintf(stderr, "%s: %s\r\n", name.c_str(), value.c_str()); fprintf(stderr, "\r\n"); /* Print response body. */ const void *body; size_t body_len; resp->get_parsed_body(&body, &body_len); fwrite(body, 1, body_len, stdout); fflush(stdout); fprintf(stderr, "\nSuccess. 
Press Ctrl-C to exit.\n"); } static WFFacilities::WaitGroup wait_group(1); void sig_handler(int signo) { wait_group.done(); } int main(int argc, char *argv[]) { WFHttpTask *task; if (argc != 2) { fprintf(stderr, "USAGE: %s \n", argv[0]); exit(1); } signal(SIGINT, sig_handler); std::string url = argv[1]; if (strncasecmp(argv[1], "http://", 7) != 0 && strncasecmp(argv[1], "https://", 8) != 0) { url = "http://" + url; } task = WFTaskFactory::create_http_task(url, REDIRECT_MAX, RETRY_MAX, wget_callback); protocol::HttpRequest *req = task->get_req(); req->add_header_pair("Accept", "*/*"); req->add_header_pair("User-Agent", "Wget/1.14 (linux-gnu)"); req->add_header_pair("Connection", "close"); task->start(); wait_group.wait(); return 0; } workflow-0.11.8/tutorial/tutorial-02-redis_cli.cc000066400000000000000000000073251476003635400216610ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Author: Xie Han (xiehan@sogou-inc.com;63350856@qq.com) */ #include #include #include #include #include #include #include "workflow/RedisMessage.h" #include "workflow/WFTaskFactory.h" #include "workflow/WFFacilities.h" #define RETRY_MAX 2 struct tutorial_task_data { std::string url; std::string key; }; void redis_callback(WFRedisTask *task) { protocol::RedisRequest *req = task->get_req(); protocol::RedisResponse *resp = task->get_resp(); int state = task->get_state(); int error = task->get_error(); protocol::RedisValue val; switch (state) { case WFT_STATE_SYS_ERROR: fprintf(stderr, "system error: %s\n", strerror(error)); break; case WFT_STATE_DNS_ERROR: fprintf(stderr, "DNS error: %s\n", gai_strerror(error)); break; case WFT_STATE_SSL_ERROR: fprintf(stderr, "SSL error: %d\n", error); break; case WFT_STATE_TASK_ERROR: fprintf(stderr, "Task error: %d\n", error); break; case WFT_STATE_SUCCESS: resp->get_result(val); if (val.is_error()) { fprintf(stderr, "%*s\n", (int)val.string_view()->size(), val.string_view()->c_str()); state = WFT_STATE_TASK_ERROR; } break; } if (state != WFT_STATE_SUCCESS) { fprintf(stderr, "Failed. Press Ctrl-C to exit.\n"); return; } std::string cmd; req->get_command(cmd); if (cmd == "SET") { tutorial_task_data *data = (tutorial_task_data *)task->user_data; WFRedisTask *next = WFTaskFactory::create_redis_task(data->url, RETRY_MAX, redis_callback); next->get_req()->set_request("GET", { data->key }); /* Push next task(GET task) to current series. */ series_of(task)->push_back(next); fprintf(stderr, "Redis SET request success. Trying to GET...\n"); } else /* if (cmd == "GET") */ { if (val.is_string()) { fprintf(stderr, "Redis GET success. value = %s\n", val.string_value().c_str()); } else { fprintf(stderr, "Error: Not a string value. \n"); } fprintf(stderr, "Finished. 
Press Ctrl-C to exit.\n"); } } static WFFacilities::WaitGroup wait_group(1); void sig_handler(int signo) { wait_group.done(); } int main(int argc, char *argv[]) { WFRedisTask *task; if (argc != 4) { fprintf(stderr, "USAGE: %s \n", argv[0]); exit(1); } signal(SIGINT, sig_handler); /* This struct only used in this tutorial. */ struct tutorial_task_data data; /* Redis URL format: redis://:password@host:port/dbnum examples: redis://127.0.0.1 redis://:12345@redis.sogou:6379/3 */ data.url = argv[1]; if (strncasecmp(argv[1], "redis://", 8) != 0 && strncasecmp(argv[1], "rediss://", 9) != 0) { data.url = "redis://" + data.url; } data.key = argv[2]; task = WFTaskFactory::create_redis_task(data.url, RETRY_MAX, redis_callback); protocol::RedisRequest *req = task->get_req(); req->set_request("SET", { data.key, argv[3] }); /* task->user_data is a public (void *), can store anything. */ task->user_data = &data; /* task->start() equel to: * Workflow::start_series_work(task, nullptr) or * Workflow::create_series_work(task, nullptr)->start() */ task->start(); wait_group.wait(); return 0; } workflow-0.11.8/tutorial/tutorial-03-wget_to_redis.cc000066400000000000000000000106341476003635400225600ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Xie Han (xiehan@sogou-inc.com;63350856@qq.com) */ /* Tuturial-03. 
Store wget result in redis: key=URL, value=Http Body*/ #include #include #include #include #include #include "workflow/HttpMessage.h" #include "workflow/HttpUtil.h" #include "workflow/RedisMessage.h" #include "workflow/Workflow.h" #include "workflow/WFTaskFactory.h" #include "workflow/WFFacilities.h" using namespace protocol; #define REDIRECT_MAX 5 #define RETRY_MAX 2 struct tutorial_series_context { std::string http_url; std::string redis_url; size_t body_len; bool success; }; void redis_callback(WFRedisTask *task) { int state = task->get_state(); tutorial_series_context *context = (tutorial_series_context *)series_of(task)->get_context(); RedisValue value; if (state == WFT_STATE_SUCCESS) { task->get_resp()->get_result(value); if (!value.is_error()) { fprintf(stderr, "redis SET success: key: %s, value size: %zu\n", context->http_url.c_str(), context->body_len); context->success = true; } else fprintf(stderr, "redis error reply! Need password?\n"); } else { fprintf(stderr, "redis SET error: state = %d, error = %d\n", state, task->get_error()); } } void http_callback(WFHttpTask *task) { HttpResponse *resp = task->get_resp(); int state = task->get_state(); int error = task->get_error(); if (state != WFT_STATE_SUCCESS) { fprintf(stderr, "http task error: state = %d, error = %d\n", state, error); return; } SeriesWork *series = series_of(task); /* get the series of this task */ tutorial_series_context *context = (tutorial_series_context *)series->get_context(); const void *body; size_t body_len; resp->get_parsed_body(&body, &body_len); if (body_len == 0) { fprintf(stderr, "Error: empty http body!"); return; } context->body_len = body_len; WFRedisTask *redis_task = WFTaskFactory::create_redis_task(context->redis_url, RETRY_MAX, redis_callback); std::string value((char *)body, body_len); redis_task->get_req()->set_request("SET", { context->http_url, value }); *series << redis_task; /* equal to series->push_back(redis_task) */ } int main(int argc, char *argv[]) { 
WFHttpTask *http_task; if (argc != 3) { fprintf(stderr, "USAGE: %s \n", argv[0]); exit(1); } struct tutorial_series_context context; context.success = false; context.http_url = argv[1]; if (strncasecmp(argv[1], "http://", 7) != 0 && strncasecmp(argv[1], "https://", 8) != 0) { context.http_url = "http://" + context.http_url; } context.redis_url = argv[2]; if (strncasecmp(argv[2], "redis://", 8) != 0 && strncasecmp(argv[2], "rediss://", 9) != 0) { context.redis_url = "redis://" + context.redis_url; } http_task = WFTaskFactory::create_http_task(context.http_url, REDIRECT_MAX, RETRY_MAX, http_callback); HttpRequest *req = http_task->get_req(); req->add_header_pair("Accept", "*/*"); req->add_header_pair("User-Agent", "Wget/1.14 (linux-gnu)"); req->add_header_pair("Connection", "close"); /* Limit the http response size to 20M. */ http_task->get_resp()->set_size_limit(20 * 1024 * 1024); /* no more than 30 seconds receiving http response. */ http_task->set_receive_timeout(30 * 1000); WFFacilities::WaitGroup wait_group(1); auto series_callback = [&wait_group](const SeriesWork *series) { tutorial_series_context *context = (tutorial_series_context *) series->get_context(); if (context->success) fprintf(stderr, "Series finished. all success!\n"); else fprintf(stderr, "Series finished. failed!\n"); /* signal the main() to terminate */ wait_group.done(); }; /* Create a series */ SeriesWork *series = Workflow::create_series_work(http_task, series_callback); series->set_context(&context); series->start(); wait_group.wait(); return 0; } workflow-0.11.8/tutorial/tutorial-04-http_echo_server.cc000066400000000000000000000062711476003635400232700ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Xie Han (xiehan@sogou-inc.com;63350856@qq.com) */ #include #include #include #include #include #include #include #include #include #include "workflow/HttpMessage.h" #include "workflow/HttpUtil.h" #include "workflow/WFServer.h" #include "workflow/WFHttpServer.h" #include "workflow/WFFacilities.h" void process(WFHttpTask *server_task) { protocol::HttpRequest *req = server_task->get_req(); protocol::HttpResponse *resp = server_task->get_resp(); long long seq = server_task->get_task_seq(); protocol::HttpHeaderCursor cursor(req); std::string name; std::string value; char buf[8192]; int len; /* Set response message body. */ resp->append_output_body_nocopy("", 6); len = snprintf(buf, 8192, "

%s %s %s

", req->get_method(), req->get_request_uri(), req->get_http_version()); resp->append_output_body(buf, len); while (cursor.next(name, value)) { len = snprintf(buf, 8192, "

%s: %s

", name.c_str(), value.c_str()); resp->append_output_body(buf, len); } resp->append_output_body_nocopy("", 7); /* Set status line if you like. */ resp->set_http_version("HTTP/1.1"); resp->set_status_code("200"); resp->set_reason_phrase("OK"); resp->add_header_pair("Content-Type", "text/html"); resp->add_header_pair("Server", "Sogou WFHttpServer"); if (seq == 9) /* no more than 10 requests on the same connection. */ resp->add_header_pair("Connection", "close"); /* print some log */ char addrstr[128]; struct sockaddr_storage addr; socklen_t l = sizeof addr; unsigned short port = 0; server_task->get_peer_addr((struct sockaddr *)&addr, &l); if (addr.ss_family == AF_INET) { struct sockaddr_in *sin = (struct sockaddr_in *)&addr; inet_ntop(AF_INET, &sin->sin_addr, addrstr, 128); port = ntohs(sin->sin_port); } else if (addr.ss_family == AF_INET6) { struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)&addr; inet_ntop(AF_INET6, &sin6->sin6_addr, addrstr, 128); port = ntohs(sin6->sin6_port); } else strcpy(addrstr, "Unknown"); fprintf(stderr, "Peer address: %s:%d, seq: %lld.\n", addrstr, port, seq); } static WFFacilities::WaitGroup wait_group(1); void sig_handler(int signo) { wait_group.done(); } int main(int argc, char *argv[]) { unsigned short port; if (argc != 2) { fprintf(stderr, "USAGE: %s \n", argv[0]); exit(1); } signal(SIGINT, sig_handler); WFHttpServer server(process); port = atoi(argv[1]); if (server.start(port) == 0) { wait_group.wait(); server.stop(); } else { perror("Cannot start server"); exit(1); } return 0; } workflow-0.11.8/tutorial/tutorial-05-http_proxy.cc000066400000000000000000000111771476003635400221470ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Xie Han (xiehan@sogou-inc.com;63350856@qq.com) */ #include #include #include #include #include "workflow/Workflow.h" #include "workflow/HttpMessage.h" #include "workflow/HttpUtil.h" #include "workflow/WFHttpServer.h" #include "workflow/WFFacilities.h" struct tutorial_series_context { std::string url; WFHttpTask *proxy_task; bool is_keep_alive; }; void reply_callback(WFHttpTask *proxy_task) { SeriesWork *series = series_of(proxy_task); tutorial_series_context *context = (tutorial_series_context *)series->get_context(); auto *proxy_resp = proxy_task->get_resp(); size_t size = proxy_resp->get_output_body_size(); if (proxy_task->get_state() == WFT_STATE_SUCCESS) fprintf(stderr, "%s: Success. Http Status: %s, BodyLength: %zu\n", context->url.c_str(), proxy_resp->get_status_code(), size); else /* WFT_STATE_SYS_ERROR*/ fprintf(stderr, "%s: Reply failed: %s, BodyLength: %zu\n", context->url.c_str(), strerror(proxy_task->get_error()), size); } void http_callback(WFHttpTask *task) { int state = task->get_state(); int error = task->get_error(); auto *resp = task->get_resp(); SeriesWork *series = series_of(task); tutorial_series_context *context = (tutorial_series_context *)series->get_context(); auto *proxy_resp = context->proxy_task->get_resp(); if (state == WFT_STATE_SUCCESS) { const void *body; size_t len; /* set a callback for getting reply status. */ context->proxy_task->set_callback(reply_callback); /* Copy the remote webserver's response, to proxy response. 
*/ resp->get_parsed_body(&body, &len); resp->append_output_body_nocopy(body, len); *proxy_resp = std::move(*resp); if (!context->is_keep_alive) proxy_resp->set_header_pair("Connection", "close"); } else { const char *err_string; if (state == WFT_STATE_SYS_ERROR) err_string = strerror(error); else if (state == WFT_STATE_DNS_ERROR) err_string = gai_strerror(error); else if (state == WFT_STATE_SSL_ERROR) err_string = "SSL error"; else /* if (state == WFT_STATE_TASK_ERROR) */ err_string = "URL error (Cannot be a HTTPS proxy)"; fprintf(stderr, "%s: Fetch failed. state = %d, error = %d: %s\n", context->url.c_str(), state, error, err_string); /* As a tutorial, make it simple. And ignore reply status. */ proxy_resp->set_status_code("404"); proxy_resp->append_output_body_nocopy( "404 Not Found.", 27); } } void process(WFHttpTask *proxy_task) { auto *req = proxy_task->get_req(); SeriesWork *series = series_of(proxy_task); WFHttpTask *http_task; /* for requesting remote webserver. */ tutorial_series_context *context = new tutorial_series_context; context->url = req->get_request_uri(); context->proxy_task = proxy_task; series->set_context(context); series->set_callback([](const SeriesWork *series) { delete (tutorial_series_context *)series->get_context(); }); context->is_keep_alive = req->is_keep_alive(); http_task = WFTaskFactory::create_http_task(req->get_request_uri(), 0, 0, http_callback); const void *body; size_t len; /* Copy user's request to the new task's reuqest using std::move() */ req->set_request_uri(http_task->get_req()->get_request_uri()); req->get_parsed_body(&body, &len); req->append_output_body_nocopy(body, len); *http_task->get_req() = std::move(*req); /* also, limit the remote webserver response size. 
*/ http_task->get_resp()->set_size_limit(200 * 1024 * 1024); *series << http_task; } static WFFacilities::WaitGroup wait_group(1); void sig_handler(int signo) { wait_group.done(); } int main(int argc, char *argv[]) { unsigned short port; if (argc != 2) { fprintf(stderr, "USAGE: %s \n", argv[0]); exit(1); } port = atoi(argv[1]); signal(SIGINT, sig_handler); struct WFServerParams params = HTTP_SERVER_PARAMS_DEFAULT; /* for safety, limit request size to 8MB. */ params.request_size_limit = 8 * 1024 * 1024; WFHttpServer server(¶ms, process); if (server.start(port) == 0) { wait_group.wait(); server.stop(); } else { perror("Cannot start server"); exit(1); } return 0; } workflow-0.11.8/tutorial/tutorial-06-parallel_wget.cc000066400000000000000000000054741476003635400225550ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Author: Xie Han (xiehan@sogou-inc.com;63350856@qq.com) */ #include #include #include #include #include "workflow/Workflow.h" #include "workflow/WFTaskFactory.h" #include "workflow/HttpMessage.h" #include "workflow/HttpUtil.h" #include "workflow/WFFacilities.h" using namespace protocol; #define REDIRECT_MAX 5 #define RETRY_MAX 2 struct tutorial_series_context { std::string url; int state; int error; HttpResponse resp; }; void callback(const ParallelWork *pwork) { tutorial_series_context *ctx; const void *body; size_t size; size_t i; for (i = 0; i < pwork->size(); i++) { ctx = (tutorial_series_context *)pwork->series_at(i)->get_context(); printf("%s\n", ctx->url.c_str()); if (ctx->state == WFT_STATE_SUCCESS) { ctx->resp.get_parsed_body(&body, &size); printf("%zu%s\n", size, ctx->resp.is_chunked() ? " chunked" : ""); fwrite(body, 1, size, stdout); printf("\n"); } else printf("ERROR! state = %d, error = %d\n", ctx->state, ctx->error); delete ctx; } } int main(int argc, char *argv[]) { ParallelWork *pwork = Workflow::create_parallel_work(callback); SeriesWork *series; WFHttpTask *task; HttpRequest *req; tutorial_series_context *ctx; int i; for (i = 1; i < argc; i++) { std::string url(argv[i]); if (strncasecmp(argv[i], "http://", 7) != 0 && strncasecmp(argv[i], "https://", 8) != 0) { url = "http://" +url; } task = WFTaskFactory::create_http_task(url, REDIRECT_MAX, RETRY_MAX, [](WFHttpTask *task) { tutorial_series_context *ctx = (tutorial_series_context *)series_of(task)->get_context(); ctx->state = task->get_state(); ctx->error = task->get_error(); ctx->resp = std::move(*task->get_resp()); }); req = task->get_req(); req->add_header_pair("Accept", "*/*"); req->add_header_pair("User-Agent", "Wget/1.14 (linux-gnu)"); req->add_header_pair("Connection", "close"); ctx = new tutorial_series_context; ctx->url = std::move(url); series = Workflow::create_series_work(task, nullptr); series->set_context(ctx); pwork->add_series(series); } WFFacilities::WaitGroup wait_group(1); 
Workflow::start_series_work(pwork, [&wait_group](const SeriesWork *) { wait_group.done(); }); wait_group.wait(); return 0; } workflow-0.11.8/tutorial/tutorial-07-sort_task.cc000066400000000000000000000050371476003635400217400ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Xie Han (xiehan@sogou-inc.com;63350856@qq.com) */ #include #include #include "workflow/WFAlgoTaskFactory.h" #include "workflow/WFFacilities.h" using namespace algorithm; static WFFacilities::WaitGroup wait_group(1); bool use_parallel_sort = false; void callback(WFSortTask *task) { /* Sort task's input and output are identical. */ SortInput *input = task->get_input(); int *first = input->first; int *last = input->last; /* You may remove this output to test speed. 
*/ int *p = first; while (p < last) printf("%d ", *p++); printf("\n"); if (task->user_data == NULL) { auto cmp = [](int a1, int a2)->bool{return a2 *reverse; if (use_parallel_sort) reverse = WFAlgoTaskFactory::create_psort_task("sort", first, last, cmp, callback); else reverse = WFAlgoTaskFactory::create_sort_task("sort", first, last, cmp, callback); reverse->user_data = (void *)1; /* as a flag */ series_of(task)->push_back(reverse); printf("Sort reversely:\n"); } else wait_group.done(); } int main(int argc, char *argv[]) { size_t count; int *array; int *end; size_t i; if (argc != 2 && argc != 3) { fprintf(stderr, "USAGE: %s [p]\n", argv[0]); exit(1); } count = atoi(argv[1]); array = (int *)malloc(count * sizeof (int)); if (!array) { perror("malloc"); exit(1); } if (argc == 3 && (*argv[2] == 'p' || *argv[2] == 'P')) use_parallel_sort = true; for (i = 0; i < count; i++) array[i] = rand() % 65536; end = &array[count]; WFSortTask *task; if (use_parallel_sort) task = WFAlgoTaskFactory::create_psort_task("sort", array, end, callback); else task = WFAlgoTaskFactory::create_sort_task("sort", array, end, callback); if (use_parallel_sort) printf("Start sorting parallelly...\n"); else printf("Start sorting...\n"); printf("Sort result:\n"); task->start(); wait_group.wait(); free(array); return 0; } workflow-0.11.8/tutorial/tutorial-08-matrix_multiply.cc000066400000000000000000000060401476003635400231660ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Author: Xie Han (xiehan@sogou-inc.com;63350856@qq.com) */ #include #include #include #include #include #include "workflow/WFTaskFactory.h" #include "workflow/WFFacilities.h" namespace algorithm { typedef std::vector> Matrix; struct MMInput { Matrix a; Matrix b; }; struct MMOutput { int error; size_t m, n, k; Matrix c; }; bool is_valid_matrix(const Matrix& matrix, size_t& m, size_t& n) { m = n = 0; if (matrix.size() == 0) return true; m = matrix.size(); n = matrix[0].size(); if (n == 0) return false; for (const auto& row : matrix) if (row.size() != n) return false; return true; } void matrix_multiply(const MMInput *in, MMOutput *out) { size_t m1, n1; size_t m2, n2; if (!is_valid_matrix(in->a, m1, n1) || !is_valid_matrix(in->b, m2, n2)) { out->error = EINVAL; return; } if (n1 != m2) { out->error = EINVAL; return; } out->error = 0; out->m = m1; out->n = n2; out->k = n1; out->c.resize(m1); for (size_t i = 0; i < out->m; i++) { out->c[i].resize(n2); for (size_t j = 0; j < out->n; j++) { out->c[i][j] = 0; for (size_t k = 0; k < out->k; k++) out->c[i][j] += in->a[i][k] * in->b[k][j]; } } } } using MMTask = WFThreadTask; using namespace algorithm; void print_matrix(const Matrix& matrix, size_t m, size_t n) { for (size_t i = 0; i < m; i++) { for (size_t j = 0; j < n; j++) printf("\t%8.2lf", matrix[i][j]); printf("\n"); } } void callback(MMTask *task) { auto *input = task->get_input(); auto *output = task->get_output(); assert(task->get_state() == WFT_STATE_SUCCESS); if (output->error) printf("Error: %d %s\n", output->error, strerror(output->error)); else { printf("Matrix A\n"); print_matrix(input->a, output->m, output->k); printf("Matrix B\n"); print_matrix(input->b, output->k, output->n); printf("Matrix A * Matrix B =>\n"); print_matrix(output->c, output->m, output->n); } } int main() { using MMFactory = WFThreadTaskFactory; MMTask *task = MMFactory::create_thread_task("matrix_multiply_task", matrix_multiply, callback); auto *input = task->get_input(); input->a = {{1, 2, 
3}, {4, 5, 6}}; input->b = {{7, 8}, {9, 10}, {11, 12}}; WFFacilities::WaitGroup wait_group(1); Workflow::start_series_work(task, [&wait_group](const SeriesWork *) { wait_group.done(); }); wait_group.wait(); return 0; } workflow-0.11.8/tutorial/tutorial-09-http_file_server.cc000066400000000000000000000105651476003635400232770ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Xie Han (xiehan@sogou-inc.com;63350856@qq.com) */ #include #include #include #include #include #include #include #include "workflow/HttpMessage.h" #include "workflow/HttpUtil.h" #include "workflow/WFHttpServer.h" #include "workflow/WFTaskFactory.h" #include "workflow/Workflow.h" #include "workflow/WFFacilities.h" using namespace protocol; void pread_callback(WFFileIOTask *task) { FileIOArgs *args = task->get_args(); long ret = task->get_retval(); HttpResponse *resp = (HttpResponse *)task->user_data; close(args->fd); if (task->get_state() != WFT_STATE_SUCCESS || ret < 0) { resp->set_status_code("503"); resp->append_output_body("503 Internal Server Error."); } else /* Use '_nocopy' carefully. 
*/ resp->append_output_body_nocopy(args->buf, ret); } void process(WFHttpTask *server_task, const char *root) { HttpRequest *req = server_task->get_req(); HttpResponse *resp = server_task->get_resp(); const char *uri = req->get_request_uri(); const char *p = uri; printf("Request-URI: %s\n", uri); while (*p && *p != '?') p++; std::string abs_path(uri, p - uri); abs_path = root + abs_path; if (abs_path.back() == '/') abs_path += "index.html"; resp->add_header_pair("Server", "Sogou C++ Workflow Server"); int fd = open(abs_path.c_str(), O_RDONLY); if (fd >= 0) { size_t size = lseek(fd, 0, SEEK_END); void *buf = malloc(size); /* As an example, assert(buf != NULL); */ WFFileIOTask *pread_task; pread_task = WFTaskFactory::create_pread_task(fd, buf, size, 0, pread_callback); /* To implement a more complicated server, please use series' context * instead of tasks' user_data to pass/store internal data. */ pread_task->user_data = resp; /* pass resp pointer to pread task. */ server_task->user_data = buf; /* to free() in callback() */ server_task->set_callback([](WFHttpTask *t){ free(t->user_data); }); series_of(server_task)->push_back(pread_task); } else { resp->set_status_code("404"); resp->append_output_body("404 Not Found."); } } static WFFacilities::WaitGroup wait_group(1); void sig_handler(int signo) { wait_group.done(); } int main(int argc, char *argv[]) { if (argc != 2 && argc != 3 && argc != 5) { fprintf(stderr, "%s [root path] [cert file] [key file]\n", argv[0]); exit(1); } signal(SIGINT, sig_handler); unsigned short port = atoi(argv[1]); const char *root = (argc >= 3 ? argv[2] : "."); auto&& proc = std::bind(process, std::placeholders::_1, root); WFHttpServer server(proc); std::string scheme; int ret; if (argc == 5) { ret = server.start(port, argv[3], argv[4]); /* https server */ scheme = "https://"; } else { ret = server.start(port); scheme = "http://"; } if (ret < 0) { perror("start server"); exit(1); } /* Test the server. 
*/ auto&& create = [&scheme, port](WFRepeaterTask *)->SubTask *{ char buf[1024]; *buf = '\0'; printf("Input file name: (Ctrl-D to exit): "); scanf("%1023s", buf); if (*buf == '\0') { printf("\n"); return NULL; } std::string url = scheme + "127.0.0.1:" + std::to_string(port) + "/" + buf; WFHttpTask *task = WFTaskFactory::create_http_task(url, 0, 0, [](WFHttpTask *task) { auto *resp = task->get_resp(); if (strcmp(resp->get_status_code(), "200") == 0) { std::string body = protocol::HttpUtil::decode_chunked_body(resp); fwrite(body.c_str(), body.size(), 1, stdout); printf("\n"); } else { printf("%s %s\n", resp->get_status_code(), resp->get_reason_phrase()); } }); return task; }; WFFacilities::WaitGroup wg(1); WFRepeaterTask *repeater; repeater = WFTaskFactory::create_repeater_task(create, [&wg](WFRepeaterTask *) { wg.done(); }); repeater->start(); wg.wait(); server.stop(); return 0; } workflow-0.11.8/tutorial/tutorial-10-user_defined_protocol/000077500000000000000000000000001476003635400237625ustar00rootroot00000000000000workflow-0.11.8/tutorial/tutorial-10-user_defined_protocol/client-uds.cc000066400000000000000000000060311476003635400263400ustar00rootroot00000000000000/* Copyright (c) 2022 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Author: Xie Han (xiehan@sogou-inc.com;63350856@qq.com) */ #include #include #include #include #include "workflow/Workflow.h" #include "workflow/WFTaskFactory.h" #include "workflow/WFFacilities.h" #include "message.h" using WFTutorialTask = WFNetworkTask; using tutorial_callback_t = std::function; using namespace protocol; class MyFactory : public WFTaskFactory { public: static WFTutorialTask *create_tutorial_task(const struct sockaddr *addr, socklen_t addrlen, int retry_max, tutorial_callback_t callback) { using NTF = WFNetworkTaskFactory; WFTutorialTask *task = NTF::create_client_task(TT_TCP, addr, addrlen, retry_max, std::move(callback)); task->set_keep_alive(30 * 1000); return task; } }; int main(int argc, char *argv[]) { const char *path; std::string host; if (argc != 2) { fprintf(stderr, "USAGE: %s \n", argv[0]); exit(1); } path = argv[1]; auto&& create = [path](WFRepeaterTask *)->SubTask *{ char buf[1024]; printf("Input next request string (Ctrl-D to exit): "); *buf = '\0'; scanf("%1023s", buf); size_t body_size = strlen(buf); if (body_size == 0) { printf("\n"); return NULL; } struct sockaddr_un sun = { }; sun.sun_family = AF_UNIX; strncpy(sun.sun_path, path, sizeof sun.sun_path - 1); WFTutorialTask *task = MyFactory::create_tutorial_task( (struct sockaddr *)&sun, sizeof sun, 0, [](WFTutorialTask *task) { int state = task->get_state(); int error = task->get_error(); TutorialResponse *resp = task->get_resp(); void *body; size_t body_size; if (state == WFT_STATE_SUCCESS) { resp->get_message_body_nocopy(&body, &body_size); printf("Server Response: %.*s\n", (int)body_size, (char *)body); } else { const char *str = WFGlobal::get_error_string(state, error); fprintf(stderr, "Error: %s\n", str); } }); task->get_req()->set_message_body(buf, body_size); task->get_resp()->set_size_limit(4 * 1024); return task; }; WFFacilities::WaitGroup wait_group(1); WFRepeaterTask *repeater; repeater = WFTaskFactory::create_repeater_task(std::move(create), nullptr); 
Workflow::start_series_work(repeater, [&wait_group](const SeriesWork *) { wait_group.done(); }); wait_group.wait(); return 0; } workflow-0.11.8/tutorial/tutorial-10-user_defined_protocol/client.cc000066400000000000000000000055571476003635400255630ustar00rootroot00000000000000/* Copyright (c) 2020 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Xie Han (xiehan@sogou-inc.com;63350856@qq.com) */ #include #include #include "workflow/Workflow.h" #include "workflow/WFTaskFactory.h" #include "workflow/WFFacilities.h" #include "message.h" using WFTutorialTask = WFNetworkTask; using tutorial_callback_t = std::function; using namespace protocol; class MyFactory : public WFTaskFactory { public: static WFTutorialTask *create_tutorial_task(const std::string& host, unsigned short port, int retry_max, tutorial_callback_t callback) { using NTF = WFNetworkTaskFactory; WFTutorialTask *task = NTF::create_client_task(TT_TCP, host, port, retry_max, std::move(callback)); task->set_keep_alive(30 * 1000); return task; } }; int main(int argc, char *argv[]) { unsigned short port; std::string host; if (argc != 3) { fprintf(stderr, "USAGE: %s \n", argv[0]); exit(1); } host = argv[1]; port = atoi(argv[2]); auto&& create = [host, port](WFRepeaterTask *)->SubTask *{ char buf[1024]; printf("Input next request string (Ctrl-D to exit): "); *buf = '\0'; scanf("%1023s", buf); size_t body_size = strlen(buf); if (body_size == 0) { printf("\n"); return NULL; } WFTutorialTask *task = MyFactory::create_tutorial_task(host, 
port, 0, [](WFTutorialTask *task) { int state = task->get_state(); int error = task->get_error(); TutorialResponse *resp = task->get_resp(); void *body; size_t body_size; if (state == WFT_STATE_SUCCESS) { resp->get_message_body_nocopy(&body, &body_size); printf("Server Response: %.*s\n", (int)body_size, (char *)body); } else { const char *str = WFGlobal::get_error_string(state, error); fprintf(stderr, "Error: %s\n", str); } }); task->get_req()->set_message_body(buf, body_size); task->get_resp()->set_size_limit(4 * 1024); return task; }; WFFacilities::WaitGroup wait_group(1); WFRepeaterTask *repeater; repeater = WFTaskFactory::create_repeater_task(std::move(create), nullptr); Workflow::start_series_work(repeater, [&wait_group](const SeriesWork *) { wait_group.done(); }); wait_group.wait(); return 0; } workflow-0.11.8/tutorial/tutorial-10-user_defined_protocol/message.cc000066400000000000000000000061041476003635400257160ustar00rootroot00000000000000/* Copyright (c) 2020 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Author: Xie Han (xiehan@sogou-inc.com;63350856@qq.com) */ #include #include #include #include #include #include #include "message.h" namespace protocol { int TutorialMessage::encode(struct iovec vectors[], int max/*max==8192*/) { uint32_t n = htonl(this->body_size); memcpy(this->head, &n, 4); vectors[0].iov_base = this->head; vectors[0].iov_len = 4; vectors[1].iov_base = this->body; vectors[1].iov_len = this->body_size; return 2; /* return the number of vectors used, no more then max. */ } int TutorialMessage::append(const void *buf, size_t size) { if (this->head_received < 4) { size_t head_left; void *p; p = &this->head[head_received]; head_left = 4 - this->head_received; if (size < 4 - this->head_received) { memcpy(p, buf, size); this->head_received += size; return 0; } this->head_received += head_left; memcpy(p, buf, head_left); size -= head_left; buf = (const char *)buf + head_left; p = this->head; this->body_size = ntohl(*(uint32_t *)p); if (this->body_size > this->size_limit) { errno = EMSGSIZE; return -1; } this->body = (char *)malloc(this->body_size); if (!this->body) return -1; this->body_received = 0; } size_t body_left = this->body_size - this->body_received; if (size > body_left) { errno = EBADMSG; return -1; } memcpy(this->body + this->body_received, buf, size); this->body_received += size; if (size < body_left) return 0; return 1; } int TutorialMessage::set_message_body(const void *body, size_t size) { void *p = malloc(size); if (!p) return -1; memcpy(p, body, size); free(this->body); this->body = (char *)p; this->body_size = size; this->head_received = 4; this->body_received = size; return 0; } TutorialMessage::TutorialMessage(TutorialMessage&& msg) : ProtocolMessage(std::move(msg)) { memcpy(this->head, msg.head, 4); this->head_received = msg.head_received; this->body = msg.body; this->body_received = msg.body_received; this->body_size = msg.body_size; msg.head_received = 0; msg.body = NULL; msg.body_size = 0; } TutorialMessage& 
TutorialMessage::operator = (TutorialMessage&& msg) { if (&msg != this) { *(ProtocolMessage *)this = std::move(msg); memcpy(this->head, msg.head, 4); this->head_received = msg.head_received; this->body = msg.body; this->body_received = msg.body_received; this->body_size = msg.body_size; msg.head_received = 0; msg.body = NULL; msg.body_size = 0; } return *this; } } workflow-0.11.8/tutorial/tutorial-10-user_defined_protocol/message.h000066400000000000000000000030471476003635400255630ustar00rootroot00000000000000/* Copyright (c) 2020 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Author: Xie Han (xiehan@sogou-inc.com;63350856@qq.com) */ #ifndef _TUTORIALMESSAGE_H_ #define _TUTORIALMESSAGE_H_ #include #include "workflow/ProtocolMessage.h" namespace protocol { class TutorialMessage : public ProtocolMessage { private: virtual int encode(struct iovec vectors[], int max); virtual int append(const void *buf, size_t size); public: int set_message_body(const void *body, size_t size); void get_message_body_nocopy(void **body, size_t *size) { *body = this->body; *size = this->body_size; } protected: char head[4]; size_t head_received; char *body; size_t body_received; size_t body_size; public: TutorialMessage() { this->head_received = 0; this->body = NULL; this->body_size = 0; } TutorialMessage(TutorialMessage&& msg); TutorialMessage& operator = (TutorialMessage&& msg); virtual ~TutorialMessage() { free(this->body); } }; using TutorialRequest = TutorialMessage; using TutorialResponse = TutorialMessage; } #endif workflow-0.11.8/tutorial/tutorial-10-user_defined_protocol/server-uds.cc000066400000000000000000000041361476003635400263740ustar00rootroot00000000000000/* Copyright (c) 2022 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Author: Xie Han (xiehan@sogou-inc.com;63350856@qq.com) */ #include #include #include #include #include #include #include "workflow/Workflow.h" #include "workflow/WFTaskFactory.h" #include "workflow/WFServer.h" #include "workflow/WFFacilities.h" #include "message.h" using WFTutorialTask = WFNetworkTask; using WFTutorialServer = WFServer; using namespace protocol; void process(WFTutorialTask *task) { TutorialRequest *req = task->get_req(); TutorialResponse *resp = task->get_resp(); void *body; size_t size; size_t i; req->get_message_body_nocopy(&body, &size); for (i = 0; i < size; i++) ((char *)body)[i] = toupper(((char *)body)[i]); resp->set_message_body(body, size); } static WFFacilities::WaitGroup wait_group(1); void sig_handler(int signo) { wait_group.done(); } int main(int argc, char *argv[]) { struct sockaddr_un sun = { }; if (argc != 2) { fprintf(stderr, "USAGE %s \n", argv[0]); exit(1); } sun.sun_family = AF_UNIX; strncpy(sun.sun_path, argv[1], sizeof sun.sun_path - 1); signal(SIGINT, sig_handler); struct WFServerParams params = SERVER_PARAMS_DEFAULT; params.request_size_limit = 4 * 1024; WFTutorialServer server(¶ms, process); if (server.start((struct sockaddr *)&sun, sizeof sun) == 0) { wait_group.wait(); server.stop(); } else { perror("server.start"); exit(1); } return 0; } workflow-0.11.8/tutorial/tutorial-10-user_defined_protocol/server.cc000066400000000000000000000037771476003635400256150ustar00rootroot00000000000000/* Copyright (c) 2020 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. Author: Xie Han (xiehan@sogou-inc.com;63350856@qq.com) */ #include #include #include #include #include "workflow/Workflow.h" #include "workflow/WFTaskFactory.h" #include "workflow/WFServer.h" #include "workflow/WFFacilities.h" #include "message.h" using WFTutorialTask = WFNetworkTask; using WFTutorialServer = WFServer; using namespace protocol; void process(WFTutorialTask *task) { TutorialRequest *req = task->get_req(); TutorialResponse *resp = task->get_resp(); void *body; size_t size; size_t i; req->get_message_body_nocopy(&body, &size); for (i = 0; i < size; i++) ((char *)body)[i] = toupper(((char *)body)[i]); resp->set_message_body(body, size); } static WFFacilities::WaitGroup wait_group(1); void sig_handler(int signo) { wait_group.done(); } int main(int argc, char *argv[]) { unsigned short port; if (argc != 2) { fprintf(stderr, "USAGE %s \n", argv[0]); exit(1); } port = atoi(argv[1]); signal(SIGINT, sig_handler); struct WFServerParams params = SERVER_PARAMS_DEFAULT; params.request_size_limit = 4 * 1024; WFTutorialServer server(¶ms, process); if (server.start(AF_INET6, port) == 0 || server.start(AF_INET, port) == 0) { wait_group.wait(); server.stop(); } else { perror("server.start"); exit(1); } return 0; } workflow-0.11.8/tutorial/tutorial-10-user_defined_protocol/xmake.lua000066400000000000000000000010661476003635400255750ustar00rootroot00000000000000add_deps("workflow") target("user_defined_message") set_kind("object") add_files("message.cc") target("user_defined_server") set_kind("binary") add_files("server.cc") add_deps("user_defined_message") target("server-uds") set_kind("binary") add_files("server-uds.cc") add_deps("user_defined_message") target("user_defined_client") set_kind("binary") add_files("client.cc") add_deps("user_defined_message") target("client-uds") set_kind("binary") add_files("client-uds.cc") add_deps("user_defined_message") 
workflow-0.11.8/tutorial/tutorial-11-graph_task.cc000066400000000000000000000050271476003635400220440ustar00rootroot00000000000000/* Copyright (c) 2021 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Xie Han (xiehan@sogou-inc.com;63350856@qq.com) */ #include #include "workflow/WFTaskFactory.h" #include "workflow/WFGraphTask.h" #include "workflow/HttpMessage.h" #include "workflow/WFFacilities.h" using namespace protocol; static WFFacilities::WaitGroup wait_group(1); void go_func(const size_t *size1, const size_t *size2) { printf("page1 size = %zu, page2 size = %zu\n", *size1, *size2); } void http_callback(WFHttpTask *task) { size_t *size = (size_t *)task->user_data; const void *body; if (task->get_state() == WFT_STATE_SUCCESS) task->get_resp()->get_parsed_body(&body, size); else *size = (size_t)-1; } #define REDIRECT_MAX 3 #define RETRY_MAX 1 int main() { WFTimerTask *timer; WFHttpTask *http_task1; WFHttpTask *http_task2; WFGoTask *go_task; size_t size1; size_t size2; timer = WFTaskFactory::create_timer_task(1000000, [](WFTimerTask *) { printf("timer task complete(1s).\n"); }); /* Http task1 */ http_task1 = WFTaskFactory::create_http_task("https://www.sogou.com/", REDIRECT_MAX, RETRY_MAX, http_callback); http_task1->user_data = &size1; /* Http task2 */ http_task2 = WFTaskFactory::create_http_task("https://www.baidu.com/", REDIRECT_MAX, RETRY_MAX, http_callback); http_task2->user_data = &size2; /* go task will print the http pages size */ go_task = WFTaskFactory::create_go_task("go", 
go_func, &size1, &size2); /* Create a graph. Graph is also a kind of task */ WFGraphTask *graph = WFTaskFactory::create_graph_task([](WFGraphTask *) { printf("Graph task complete. Wakeup main process\n"); wait_group.done(); }); /* Create graph nodes */ WFGraphNode& a = graph->create_graph_node(timer); WFGraphNode& b = graph->create_graph_node(http_task1); WFGraphNode& c = graph->create_graph_node(http_task2); WFGraphNode& d = graph->create_graph_node(go_task); /* Build the graph */ a-->b; a-->c; b-->d; c-->d; graph->start(); wait_group.wait(); return 0; } workflow-0.11.8/tutorial/tutorial-12-mysql_cli.cc000066400000000000000000000163031476003635400217150ustar00rootroot00000000000000/* Copyright (c) 2019 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Authors: Li Yingxin (liyingxin@sogou-inc.com) */ #include #include #include #include #include #include #include #include "workflow/Workflow.h" #include "workflow/WFTaskFactory.h" #include "workflow/MySQLResult.h" #include "workflow/WFFacilities.h" using namespace protocol; #define RETRY_MAX 0 volatile bool stop_flag; void mysql_callback(WFMySQLTask *task); void get_next_cmd(WFMySQLTask *task) { int len; char query[4096]; WFMySQLTask *next_task; fprintf(stderr, "mysql> "); while ((fgets(query, 4096, stdin)) && stop_flag == false) { len = strlen(query); if (len > 0 && query[len - 1] == '\n') query[len - 1] = '\0'; if (strncmp(query, "quit", len) == 0 || strncmp(query, "exit", len) == 0) { fprintf(stderr, "Bye\n"); return; } if (len == 0 || strncmp(query, "\0", len) == 0) { fprintf(stderr, "mysql> "); continue; } std::string *url = (std::string *)series_of(task)->get_context(); next_task = WFTaskFactory::create_mysql_task(*url, RETRY_MAX, mysql_callback); next_task->get_req()->set_query(query); series_of(task)->push_back(next_task); break; } return; } void mysql_callback(WFMySQLTask *task) { MySQLResponse *resp = task->get_resp(); MySQLResultCursor cursor(resp); const MySQLField *const *fields; std::vector arr; if (task->get_state() != WFT_STATE_SUCCESS) { fprintf(stderr, "error msg: %s\n", WFGlobal::get_error_string(task->get_state(), task->get_error())); return; } do { if (cursor.get_cursor_status() != MYSQL_STATUS_GET_RESULT && cursor.get_cursor_status() != MYSQL_STATUS_OK) { break; } fprintf(stderr, "---------------- RESULT SET ----------------\n"); if (cursor.get_cursor_status() == MYSQL_STATUS_GET_RESULT) { fprintf(stderr, "cursor_status=%d field_count=%u rows_count=%u\n", cursor.get_cursor_status(), cursor.get_field_count(), cursor.get_rows_count()); //nocopy api fields = cursor.fetch_fields(); for (int i = 0; i < cursor.get_field_count(); i++) { if (i == 0) { fprintf(stderr, "db=%s table=%s\n", fields[i]->get_db().c_str(), fields[i]->get_table().c_str()); 
fprintf(stderr, " ---------- COLUMNS ----------\n"); } fprintf(stderr, " name[%s] type[%s]\n", fields[i]->get_name().c_str(), datatype2str(fields[i]->get_data_type())); } fprintf(stderr, " _________ COLUMNS END _________\n\n"); while (cursor.fetch_row(arr)) { fprintf(stderr, " ------------ ROW ------------\n"); for (size_t i = 0; i < arr.size(); i++) { fprintf(stderr, " [%s][%s]", fields[i]->get_name().c_str(), datatype2str(arr[i].get_data_type())); if (arr[i].is_string()) { std::string res = arr[i].as_string(); if (res.length() == 0) fprintf(stderr, "[\"\"]\n"); else fprintf(stderr, "[%s]\n", res.c_str()); } else if (arr[i].is_int()) { fprintf(stderr, "[%d]\n", arr[i].as_int()); } else if (arr[i].is_ulonglong()) { fprintf(stderr, "[%llu]\n", arr[i].as_ulonglong()); } else if (arr[i].is_float()) { const void *ptr; size_t len; int data_type; arr[i].get_cell_nocopy(&ptr, &len, &data_type); size_t pos; for (pos = 0; pos < len; pos++) if (*((const char *)ptr + pos) == '.') break; if (pos != len) pos = len - pos - 1; else pos = 0; fprintf(stderr, "[%.*f]\n", (int)pos, arr[i].as_float()); } else if (arr[i].is_double()) { const void *ptr; size_t len; int data_type; arr[i].get_cell_nocopy(&ptr, &len, &data_type); size_t pos; for (pos = 0; pos < len; pos++) if (*((const char *)ptr + pos) == '.') break; if (pos != len) pos = len - pos - 1; else pos= 0; fprintf(stderr, "[%.*lf]\n", (int)pos, arr[i].as_double()); } else if (arr[i].is_date()) { fprintf(stderr, "[%s]\n", arr[i].as_string().c_str()); } else if (arr[i].is_time()) { fprintf(stderr, "[%s]\n", arr[i].as_string().c_str()); } else if (arr[i].is_datetime()) { fprintf(stderr, "[%s]\n", arr[i].as_string().c_str()); } else if (arr[i].is_null()) { fprintf(stderr, "[NULL]\n"); } else { std::string res = arr[i].as_binary_string(); if (res.length() == 0) fprintf(stderr, "[\"\"]\n"); else fprintf(stderr, "[%s]\n", res.c_str()); } } fprintf(stderr, " __________ ROW END __________\n"); } } else if (cursor.get_cursor_status() == 
MYSQL_STATUS_OK) { fprintf(stderr, " OK. %llu ", cursor.get_affected_rows()); if (cursor.get_affected_rows() == 1) fprintf(stderr, "row "); else fprintf(stderr, "rows "); fprintf(stderr, "affected. %d warnings. insert_id=%llu. %s\n", cursor.get_warnings(), cursor.get_insert_id(), cursor.get_info().c_str()); } fprintf(stderr, "________________ RESULT SET END ________________\n\n"); } while (cursor.next_result_set()); if (resp->get_packet_type() == MYSQL_PACKET_ERROR) { fprintf(stderr, "ERROR. error_code=%d %s\n", task->get_resp()->get_error_code(), task->get_resp()->get_error_msg().c_str()); } else if (resp->get_packet_type() == MYSQL_PACKET_OK) // just check origin APIs { fprintf(stderr, "OK. %llu ", task->get_resp()->get_affected_rows()); if (task->get_resp()->get_affected_rows() == 1) fprintf(stderr, "row "); else fprintf(stderr, "rows "); fprintf(stderr, "affected. %d warnings. insert_id=%llu. %s\n", task->get_resp()->get_warnings(), task->get_resp()->get_last_insert_id(), task->get_resp()->get_info().c_str()); } get_next_cmd(task); return; } static void sighandler(int signo) { stop_flag = true; } int main(int argc, char *argv[]) { WFMySQLTask *task; if (argc != 2) { fprintf(stderr, "USAGE: %s \n" " url format: mysql://root:password@host:port/dbname?character_set=charset\n" " example: mysql://root@test.mysql.com/test\n", argv[0]); return 0; } signal(SIGINT, sighandler); signal(SIGTERM, sighandler); std::string url = argv[1]; if (strncasecmp(argv[1], "mysql://", 8) != 0 && strncasecmp(argv[1], "mysqls://", 9) != 0) { url = "mysql://" + url; } const char *query = "show databases"; stop_flag = false; task = WFTaskFactory::create_mysql_task(url, RETRY_MAX, mysql_callback); task->get_req()->set_query(query); WFFacilities::WaitGroup wait_group(1); SeriesWork *series = Workflow::create_series_work(task, [&wait_group](const SeriesWork *series) { wait_group.done(); }); series->set_context(&url); series->start(); wait_group.wait(); return 0; } 
workflow-0.11.8/tutorial/tutorial-13-kafka_cli.cc000066400000000000000000000141401476003635400216230ustar00rootroot00000000000000/* Copyright (c) 2020 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Wang Zhulei (wangzhulei@sogou-inc.com) */ #include #include #include #include #include #include #include #include "workflow/WFKafkaClient.h" #include "workflow/KafkaMessage.h" #include "workflow/KafkaResult.h" #include "workflow/WFTaskFactory.h" #include "workflow/WFFacilities.h" #include "workflow/WFGlobal.h" using namespace protocol; static WFFacilities::WaitGroup wait_group(1); std::string url; bool no_cgroup = false; WFKafkaClient client; void kafka_callback(WFKafkaTask *task) { int state = task->get_state(); int error = task->get_error(); if (state != WFT_STATE_SUCCESS) { fprintf(stderr, "error msg: %s\n", WFGlobal::get_error_string(state, error)); fprintf(stderr, "Failed. 
Press Ctrl-C to exit.\n"); client.deinit(); wait_group.done(); return; } WFKafkaTask *next_task = NULL; std::vector> records; std::vector toppars; int api_type = task->get_api_type(); protocol::KafkaResult new_result; switch (api_type) { case Kafka_Produce: task->get_result()->fetch_records(records); for (const auto &v : records) { for (const auto &w: v) { const void *value; size_t value_len; w->get_value(&value, &value_len); printf("produce\ttopic: %s, partition: %d, status: %d, \ offset: %lld, val_len: %zu\n", w->get_topic(), w->get_partition(), w->get_status(), w->get_offset(), value_len); } } break; case Kafka_Fetch: new_result = std::move(*task->get_result()); new_result.fetch_records(records); if (!records.empty()) { if (!no_cgroup) next_task = client.create_kafka_task("api=commit", 3, kafka_callback); std::string out; for (const auto &v : records) { if (v.empty()) continue; char fn[1024]; snprintf(fn, 1024, "/tmp/kafka.%s.%d.%llu", v.back()->get_topic(), v.back()->get_partition(), v.back()->get_offset()); FILE *fp = fopen(fn, "w+"); long long offset = 0; int partition = 0; std::string topic; for (const auto &w : v) { const void *value; size_t value_len; w->get_value(&value, &value_len); if (fp) fwrite(value, value_len, 1, fp); offset = w->get_offset(); partition = w->get_partition(); topic = w->get_topic(); if (!no_cgroup) next_task->add_commit_record(*w); } if (!topic.empty()) { out += "topic: " + topic; out += ",partition: " + std::to_string(partition); out += ",offset: " + std::to_string(offset) + ";"; } if (fp) fclose(fp); } printf("fetch\t%s\n", out.c_str()); if (!no_cgroup) series_of(task)->push_back(next_task); } break; case Kafka_OffsetCommit: task->get_result()->fetch_toppars(toppars); if (!toppars.empty()) { for (const auto& v : toppars) { printf("commit\ttopic: %s, partition: %d, \ offset: %llu, error: %d\n", v->get_topic(), v->get_partition(), v->get_offset(), v->get_error()); } } next_task = client.create_leavegroup_task(3, kafka_callback); 
series_of(task)->push_back(next_task); break; case Kafka_LeaveGroup: printf("leavegroup callback\n"); break; default: break; } if (!next_task) { client.deinit(); wait_group.done(); } } void sig_handler(int signo) { } int main(int argc, char *argv[]) { if (argc < 3) { fprintf(stderr, "USAGE: %s url

[compress_type/d]\n", argv[0]); exit(1); } signal(SIGINT, sig_handler); url = argv[1]; if (strncmp(argv[1], "kafka://", 8) != 0 && strncmp(argv[1], "kafkas://", 9) != 0) { url = "kafka://" + url; } char buf[512 * 1024]; WFKafkaTask *task; if (argv[2][0] == 'p') { int compress_type = Kafka_NoCompress; if (argc > 3) compress_type = atoi(argv[3]); if (compress_type > Kafka_Zstd) exit(1); if (client.init(url) < 0) { perror("client.init"); exit(1); } task = client.create_kafka_task("api=produce", 3, kafka_callback); KafkaConfig config; KafkaRecord record; config.set_compress_type(compress_type); config.set_client_id("workflow"); task->set_config(std::move(config)); for (size_t i = 0; i < sizeof (buf); ++i) buf[i] = '1' + rand() % 128; record.set_key("key1", strlen("key1")); record.set_value(buf, sizeof (buf)); record.add_header_pair("hk1", 3, "hv1", 3); task->add_produce_record("workflow_test1", -1, std::move(record)); record.set_key("key2", strlen("key2")); record.set_value(buf, sizeof (buf)); record.add_header_pair("hk2", 3, "hv2", 3); task->add_produce_record("workflow_test2", -1, std::move(record)); } else if (argv[2][0] == 'c') { if (argc > 3 && argv[3][0] == 'd') { if (client.init(url) < 0) { perror("client.init"); exit(1); } task = client.create_kafka_task("api=fetch", 3, kafka_callback); KafkaToppar toppar; toppar.set_topic_partition("workflow_test1", 0); toppar.set_offset(0); task->add_toppar(toppar); toppar.set_topic_partition("workflow_test2", 0); toppar.set_offset(1); task->add_toppar(toppar); no_cgroup = true; } else { if (client.init(url, "workflow_group") < 0) { perror("client.init"); exit(1); } task = client.create_kafka_task("topic=workflow_test1&topic=workflow_test2&api=fetch", 3, kafka_callback); } KafkaConfig config; config.set_client_id("workflow"); task->set_config(std::move(config)); } else { fprintf(stderr, "USAGE: %s url

[compress_type/d]\n", argv[0]); exit(1); } task->start(); wait_group.wait(); return 0; } workflow-0.11.8/tutorial/tutorial-14-consul_cli.cc000066400000000000000000000165351476003635400220640ustar00rootroot00000000000000/* Copyright (c) 2020 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Wang Zhenpeng (wangzhenpeng@sogou-inc.com) */ #include #include #include #include #include #include #include #include #include "workflow/WFConsulClient.h" #include "workflow/ConsulDataTypes.h" #include "workflow/WFTaskFactory.h" #include "workflow/WFFacilities.h" #include "workflow/HttpMessage.h" #include "workflow/WFGlobal.h" using namespace protocol; static WFFacilities::WaitGroup wait_group(1); std::string url; WFConsulClient client; void print_discover_result(std::vector& discover_result) { for (const auto& instance : discover_result) { fprintf(stderr, "%s", "discover_instance\n"); fprintf(stderr, "node_id:%s\n", instance.node_id.c_str()); fprintf(stderr, "node_name:%s\n", instance.node_name.c_str()); fprintf(stderr, "node_address:%s\n", instance.node_address.c_str()); fprintf(stderr, "dc:%s\n", instance.dc.c_str()); const std::map& node_meta = instance.node_meta; for (const auto& meta_kv : node_meta) { fprintf(stderr, "node_meta:%s = %s\n", meta_kv.first.c_str(), meta_kv.second.c_str()); } fprintf(stderr, "create_index:%lld\n", instance.create_index); fprintf(stderr, "modify_index:%lld\n", instance.modify_index); fprintf(stderr, "service_id:%s\n", instance.service.service_id.c_str()); 
fprintf(stderr, "service_name:%s\n", instance.service.service_name.c_str()); fprintf(stderr, "service_namespace:%s\n", instance.service.service_namespace.c_str()); fprintf(stderr, "service_address:%s\n", instance.service.service_address.first.c_str()); fprintf(stderr, "service_port:%d\n", instance.service.service_address.second); fprintf(stderr, "service_tag_override:%d\n", instance.service.tag_override); fprintf(stderr, "%s", "service_tags:"); const std::vector& tags = instance.service.tags; for (const auto& tag : tags) { fprintf(stderr, "%s,", tag.c_str()); } fprintf(stderr, "\n"); const std::map& service_meta = instance.service.meta; for (const auto& meta_kv : service_meta) { fprintf(stderr, "service_meta:%s = %s\n", meta_kv.first.c_str(), meta_kv.second.c_str()); } fprintf(stderr, "lan:%s:%d\n", instance.service.lan.first.c_str(), instance.service.lan.second); fprintf(stderr, "lan_ipv4:%s:%d\n", instance.service.lan_ipv4.first.c_str(), instance.service.lan_ipv4.second); fprintf(stderr, "lan_ipv6:%s:%d\n", instance.service.lan_ipv6.first.c_str(), instance.service.lan_ipv6.second); fprintf(stderr, "wan:%s:%d\n", instance.service.wan.first.c_str(), instance.service.wan.second); fprintf(stderr, "wan_ipv4:%s:%d\n", instance.service.wan_ipv4.first.c_str(), instance.service.wan_ipv4.second); fprintf(stderr, "wan_ipv6:%s:%d\n", instance.service.wan_ipv6.first.c_str(), instance.service.wan_ipv6.second); fprintf(stderr, "check_id:%s\n", instance.check_id.c_str()); fprintf(stderr, "check_name:%s\n", instance.check_name.c_str()); fprintf(stderr, "check_notes:%s\n", instance.check_notes.c_str()); fprintf(stderr, "check_output:%s\n", instance.check_output.c_str()); fprintf(stderr, "check_status:%s\n", instance.check_status.c_str()); fprintf(stderr, "check_type:%s\n", instance.check_type.c_str()); } } void print_list_service_result( std::vector& list_service_result) { for (const auto& instance : list_service_result) { fprintf(stderr, "service name:%s tags:", 
instance.service_name.c_str()); std::string tags_log; for (const auto& tag : instance.tags) { tags_log += tag; tags_log += ","; } if (tags_log.size() > 0) tags_log.pop_back(); fprintf(stderr, "%s\n", tags_log.c_str()); } } void consul_callback(WFConsulTask *task) { int state = task->get_state(); int error = task->get_error(); if (state != WFT_STATE_SUCCESS) { fprintf(stderr, "error:%d, error msg:%s\n", error, WFGlobal::get_error_string(state, error)); fprintf(stderr, "Failed. Press Ctrl-C to exit.\n"); wait_group.done(); return; } int api_type = task->get_api_type(); std::vector dis_result; std::vector list_service_result; switch (api_type) { case CONSUL_API_TYPE_DISCOVER: fprintf(stderr, "discover ok\n"); fprintf(stderr, "consul-index:%lld\n", task->get_consul_index()); if (task->get_discover_result(dis_result)) print_discover_result(dis_result); else fprintf(stderr, "error:%d\n", task->get_error()); break; case CONSUL_API_TYPE_LIST_SERVICE: fprintf(stderr, "list service ok\n"); if (task->get_list_service_result(list_service_result)) print_list_service_result(list_service_result); else fprintf(stderr, "error:%d\n", task->get_error()); break; case CONSUL_API_TYPE_REGISTER: fprintf(stderr, "register ok\n"); break; case CONSUL_API_TYPE_DEREGISTER: fprintf(stderr, "deregister ok\n"); break; default: break; } wait_group.done(); } void sig_handler(int signo) { } int main(int argc, char *argv[]) { if (argc < 3) { fprintf(stderr, "USAGE: %s url type(discover/register/deregister)

\n", argv[0]); exit(1); } signal(SIGINT, sig_handler); url = argv[1]; if (strncmp(argv[1], "http://", 7) != 0) url = "http://" + url; ConsulConfig config; config.set_token("cd125427-3fd1-f326-bf46-fbce06cc9003"); config.set_health_check(true); // http health check config.set_check_http_url("http://127.0.0.1:8000/health_check/sd"); // config.add_http_header("Accept", {"text/html", "application/xml"}); // tcp health check //config.set_check_tcp("127.0.0.1:80"); client.init(url, config); WFConsulTask *task; if (0 == strcmp(argv[2], "discover")) { task = client.create_discover_task("", "dev-wf_test_service_1", 3, consul_callback); config.set_blocking_query(true); } else if (0 == strcmp(argv[2], "list_service")) { task = client.create_list_service_task("", 3, consul_callback); } else if (0 == strcmp(argv[2], "register")) { task = client.create_register_task("", "dev-wf_test_service_1", "wf_test_service_id_2", 3, consul_callback); protocol::ConsulService service; service.tags.emplace_back("tag1"); service.tags.emplace_back("tag2"); service.service_address.first = "127.0.0.1"; service.service_address.second = 8000; service.meta["mk1"] = "mv1"; service.meta["mk2"] = "mv2"; service.tag_override = true; task->set_service(&service); } else if (0 == strcmp(argv[2], "deregister")) { task = client.create_deregister_task("", "wf_test_service_id_2", 3, consul_callback); } else { fprintf(stderr, "USAGE: %s url

[compress_type/d]\n", argv[0]); exit(1); } task->start(); wait_group.wait(); return 0; } workflow-0.11.8/tutorial/tutorial-15-name_service.cc000066400000000000000000000105011476003635400223560ustar00rootroot00000000000000/* Copyright (c) 2021 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Author: Xie Han (xiehan@sogou-inc.com;63350856@qq.com) */ #include #include #include #include #include #include "workflow/WFGlobal.h" #include "workflow/WFNameService.h" #include "workflow/WFTaskFactory.h" #include "workflow/WFFacilities.h" #include "workflow/HttpUtil.h" // The example domonstrate the simplest user defined naming policy. /* 'MyNSPolicy' is a naming policy, which use local file for naming. * The format of naming file is similar to 'hosts' file, but we allow * domain name and IP address as destination. 
For example: * * 127.0.0.1 localhost * 127.0.0.1 mydomain # another alias for 127.0.0.1 * www.sogou.com sogou # sogou -> www.sogou.com */ class MyNSPolicy : public WFNSPolicy { public: WFRouterTask *create_router_task(const struct WFNSParams *params, router_callback_t callback) override; private: std::string path; private: std::string read_from_fp(FILE *fp, const char *name); std::string parse_line(char *p, const char *name); public: MyNSPolicy(const char *naming_file) : path(naming_file) { } }; std::string MyNSPolicy::parse_line(char *p, const char *name) { const char *dest = NULL; char *start; start = p; while (*start != '\0' && *start != '#') start++; *start = '\0'; while (1) { while (isspace(*p)) p++; start = p; while (*p != '\0' && !isspace(*p)) p++; if (start == p) break; if (*p != '\0') *p++ = '\0'; if (dest == NULL) { dest = start; continue; } if (strcasecmp(name, start) == 0) return std::string(dest); } return std::string(); } std::string MyNSPolicy::read_from_fp(FILE *fp, const char *name) { char *line = NULL; size_t bufsize = 0; std::string result; while (getline(&line, &bufsize, fp) > 0) { result = this->parse_line(line, name); if (result.size() > 0) break; } free(line); return result; } WFRouterTask *MyNSPolicy::create_router_task(const struct WFNSParams *params, router_callback_t callback) { WFDnsResolver *dns_resolver = WFGlobal::get_dns_resolver(); if (params->uri.host) { FILE *fp = fopen(this->path.c_str(), "r"); if (fp) { std::string dest = this->read_from_fp(fp, params->uri.host); if (dest.size() > 0) { /* Update the uri structure's 'host' field directly. * You can also update the 'port' field if needed. */ free(params->uri.host); params->uri.host = strdup(dest.c_str()); } fclose(fp); } } /* Simply, use the global dns resolver to create a router task. 
*/ return dns_resolver->create_router_task(params, std::move(callback)); } int main(int argc, char *argv[]) { if (argc != 3) { fprintf(stderr, "USAGE: %s \n", argv[0]); exit(1); } ParsedURI uri; URIParser::parse(argv[1], uri); char *name = uri.host; if (name == NULL) { fprintf(stderr, "Invalid http URI\n"); exit(1); } /* Create an naming policy. */ MyNSPolicy *policy = new MyNSPolicy(argv[2]); /* Get the global name service object.*/ WFNameService *ns = WFGlobal::get_name_service(); /* Add the our name with policy to global name service. * You can add mutilply names with one policy object. */ ns->add_policy(name, policy); WFFacilities::WaitGroup wg(1); WFHttpTask *task = WFTaskFactory::create_http_task(argv[1], 2, 3, [&wg](WFHttpTask *task) { int state = task->get_state(); int error = task->get_error(); if (state != WFT_STATE_SUCCESS) { fprintf(stderr, "error: %s\n", WFGlobal::get_error_string(state, error)); } else { auto *r = task->get_resp(); std::string body = protocol::HttpUtil::decode_chunked_body(r); fwrite(body.c_str(), 1, body.size(), stdout); } wg.done(); }); task->start(); wg.wait(); /* clean up */ ns->del_policy(name); delete policy; return 0; } workflow-0.11.8/tutorial/tutorial-16-graceful_restart/000077500000000000000000000000001476003635400227475ustar00rootroot00000000000000workflow-0.11.8/tutorial/tutorial-16-graceful_restart/bootstrap.c000066400000000000000000000060511476003635400251320ustar00rootroot00000000000000/* Copyright (c) 2022 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. Authors: Li Yingxin (liyingxin@sogou-inc.com) */ #include #include #include #include #include #include #include #include #include int flag = 0; void sig_handler(int signo) { if (signo == SIGUSR1) flag = 1; else if (signo == SIGINT || signo == SIGTERM) flag = 2; } int main(int argc, const char *argv[]) { if (argc != 3) { fprintf(stderr, "USAGE: %s EXEC_PROCESS PORT\n" "Bootstrap for workflow server to restart gracefully.\n", argv[0]); exit(1); } unsigned short port = atoi(argv[2]); int listen_fd = socket(AF_INET, SOCK_STREAM, 0); struct sockaddr_in sin; memset(&sin, 0, sizeof sin); sin.sin_family = AF_INET; sin.sin_port = htons(port); sin.sin_addr.s_addr = htonl(INADDR_ANY); if (bind(listen_fd, (struct sockaddr *)&sin, sizeof sin) < 0) { close(listen_fd); perror("bind error"); exit(1); } pid_t pid; int pipe_fd[2]; ssize_t len; char buf[100]; int status; int ret; char listen_fd_str[10] = { 0 }; char write_fd_str[10] = { 0 }; sprintf(listen_fd_str, "%d", listen_fd); signal(SIGINT, sig_handler); signal(SIGTERM, sig_handler); signal(SIGUSR1, sig_handler); while (flag < 2) { if (pipe(pipe_fd) == -1) { perror("open pipe error"); exit(1); } memset(write_fd_str, 0, sizeof write_fd_str); sprintf(write_fd_str, "%d", pipe_fd[1]); pid = fork(); if (pid < 0) { perror("fork error"); close(pipe_fd[0]); close(pipe_fd[1]); break; } else if (pid == 0) { close(pipe_fd[0]); execlp(argv[1], argv[1], listen_fd_str, write_fd_str, NULL); } else { close(pipe_fd[1]); status = 0; ret = 0; flag = 0; fprintf(stderr, "Bootstrap daemon running with server pid-%d. 
" "Send SIGUSR1 to RESTART or SIGTERM to STOP.\n", pid); while (1) { ret = waitpid(pid, &status, WNOHANG); if (ret == -1 || !WIFEXITED(status) || flag != 0) break; sleep(3); } if (ret != -1 && WIFEXITED(status)) { signal(SIGCHLD, SIG_IGN); kill(pid, SIGUSR1); fprintf(stderr, "Bootstrap daemon SIGUSR1 to pid-%ld %sing.\n", (long)pid, flag == 1 ? "restart" : "stop"); len = read(pipe_fd[0], buf, 7); fprintf(stderr, "Bootstrap server served %*s.\n", (int)len, buf); } else { fprintf(stderr, "child exit. status = %d, waitpid ret = %d\n", WEXITSTATUS(status), ret); flag = 2; } close(pipe_fd[0]); } } close(listen_fd); return 0; } workflow-0.11.8/tutorial/tutorial-16-graceful_restart/server.cc000066400000000000000000000027331476003635400245710ustar00rootroot00000000000000/* Copyright (c) 2022 Sogou, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Authors: Li Yingxin (liyingxin@sogou-inc.com) */ #include #include #include #include #include "workflow/WFFacilities.h" #include "workflow/WFHttpServer.h" static WFFacilities::WaitGroup wait_group(1); void sig_handler(int signo) { wait_group.done(); } int main(int argc, const char *argv[]) { if (argc != 3) { fprintf(stderr, "USAGE: %s listen_fd pipe_fd\n", argv[0]); exit(1); } int listen_fd = atoi(argv[1]); int pipe_fd = atoi(argv[2]); signal(SIGUSR1, sig_handler); WFHttpServer server([](WFHttpTask *task) { task->get_resp()->append_output_body("Hello World!"); }); if (server.serve(listen_fd) == 0) { wait_group.wait(); server.shutdown(); write(pipe_fd, "success", strlen("success")); server.wait_finish(); } else write(pipe_fd, "failed ", strlen("failed ")); close(pipe_fd); close(listen_fd); return 0; } workflow-0.11.8/tutorial/tutorial-16-graceful_restart/xmake.lua000066400000000000000000000002561476003635400245620ustar00rootroot00000000000000target("bootstrap") set_kind("binary") add_files("bootstrap.c") target("bootstrap_server") set_kind("binary") add_files("server.cc") add_deps("workflow")workflow-0.11.8/tutorial/tutorial-17-dns_cli.cc000066400000000000000000000070241476003635400213410ustar00rootroot00000000000000#include #include #include #include #include #include "workflow/DnsUtil.h" #include "workflow/WFDnsClient.h" #include "workflow/WFFacilities.h" static const std::map qtype_map = { {"A", DNS_TYPE_A }, {"AAAA", DNS_TYPE_AAAA }, {"CNAME", DNS_TYPE_CNAME }, {"SOA", DNS_TYPE_SOA }, {"NS", DNS_TYPE_NS }, {"SRV", DNS_TYPE_SRV }, {"MX", DNS_TYPE_MX } }; WFFacilities::WaitGroup wait_group(1); void show_result(protocol::DnsResultCursor& cursor) { char information[1024]; const char *info; struct dns_record *record; struct dns_record_soa *soa; struct dns_record_srv *srv; struct dns_record_mx *mx; while(cursor.next(&record)) { switch (record->type) { case DNS_TYPE_A: info = inet_ntop(AF_INET, record->rdata, information, 64); break; case DNS_TYPE_AAAA: info = 
inet_ntop(AF_INET6, record->rdata, information, 64); break; case DNS_TYPE_NS: case DNS_TYPE_CNAME: case DNS_TYPE_PTR: info = (const char *)(record->rdata); break; case DNS_TYPE_SOA: soa = (struct dns_record_soa *)(record->rdata); sprintf(information, "%s %s %u %d %d %d %u", soa->mname, soa->rname, soa->serial, soa->refresh, soa->retry, soa->expire, soa->minimum ); info = information; break; case DNS_TYPE_SRV: srv = (struct dns_record_srv *)(record->rdata); sprintf(information, "%u %u %u %s", srv->priority, srv->weight, srv->port, srv->target ); info = information; break; case DNS_TYPE_MX: mx = (struct dns_record_mx *)(record->rdata); sprintf(information, "%d %s", mx->preference, mx->exchange); info = information; break; default: info = "Unknown"; break; } printf("%s\t%d\t%s\t%s\t%s\n", record->name, record->ttl, dns_class2str(record->rclass), dns_type2str(record->type), info ); } printf("\n"); } void dns_callback(WFDnsTask *task) { int state = task->get_state(); int error = task->get_error(); auto *resp = task->get_resp(); if (state != WFT_STATE_SUCCESS) { printf("State: %d, Error: %d\n", state, error); printf("Error: %s\n", WFGlobal::get_error_string(state, error)); wait_group.done(); return; } printf("; Workflow DNSResolver\n"); printf(";; HEADER opcode:%s status:%s id:%d\n", dns_opcode2str(resp->get_opcode()), dns_rcode2str(resp->get_rcode()), resp->get_id() ); printf(";; QUERY:%d ANSWER:%d AUTHORITY:%d ADDITIONAL:%d\n", resp->get_qdcount(), resp->get_ancount(), resp->get_nscount(), resp->get_arcount() ); printf("\n"); protocol::DnsResultCursor cursor(resp); if(resp->get_ancount() > 0) { cursor.reset_answer_cursor(); printf(";; ANSWER SECTION:\n"); show_result(cursor); } if(resp->get_nscount() > 0) { cursor.reset_authority_cursor(); printf(";; AUTHORITY SECTION\n"); show_result(cursor); } if(resp->get_arcount() > 0) { cursor.reset_additional_cursor(); printf(";; ADDITIONAL SECTION\n"); show_result(cursor); } wait_group.done(); } int main(int argc, char *argv[]) 
{ int qtype = DNS_TYPE_A; const char *domain; if (argc == 1 || argc > 3) { fprintf(stderr, "USAGE: %s [query type]\n", argv[0]); return 1; } domain = argv[1]; if (argc == 3) { auto it = qtype_map.find(argv[2]); if (it != qtype_map.end()) qtype = it->second; } std::string url = "dns://119.29.29.29"; WFDnsTask *task = WFTaskFactory::create_dns_task(url, 0, dns_callback); protocol::DnsRequest *req = task->get_req(); req->set_rd(1); req->set_question(domain, qtype, DNS_CLASS_IN); task->start(); wait_group.wait(); return 0; } workflow-0.11.8/tutorial/tutorial-18-redis_subscriber.cc000066400000000000000000000054201476003635400232560ustar00rootroot00000000000000#include #include #include #include #include #include "workflow/WFRedisSubscriber.h" #include "workflow/WFFacilities.h" #include "workflow/StringUtil.h" void extract(WFRedisSubscribeTask *task) { auto *resp = task->get_resp(); protocol::RedisValue value; resp->get_result(value); if (value.is_array()) { for (size_t i = 0; i < value.arr_size(); i++) { if (value[i].is_string()) std::cout << value[i].string_value(); else if (value[i].is_int()) std::cout << value[i].int_value(); else if (value[i].is_nil()) std::cout << "nil"; else std::cout << "Unexpected value in array!"; std::cout << "\n"; } } else std::cout << "Unexpected value!\n"; } int main(int argc, char *argv[]) { if (argc < 3) { std::cerr << argv[0] << " []..." 
<< std::endl; exit(1); } std::string url = argv[1]; if (strncasecmp(argv[1], "redis://", 8) != 0 && strncasecmp(argv[1], "rediss://", 9) != 0) { url = "redis://" + url; } WFRedisSubscriber suber; if (suber.init(url) != 0) { std::cerr << "Subscriber init failed " << strerror(errno) << std::endl; exit(1); } std::vector channels; for (int i = 2; i < argc; i++) channels.push_back(argv[i]); WFFacilities::WaitGroup wg(1); bool finished = false; auto callback = [&](WFRedisSubscribeTask *task) { std::cout << "state = " << task->get_state() << ", error = " << task->get_error() << std::endl; finished = true; wg.done(); }; WFRedisSubscribeTask *task; task = suber.create_subscribe_task(channels, extract, callback); task->set_watch_timeout(1000000); task->start(); std::string line; while (!finished) { std::string cmd; std::vector params; if (std::getline(std::cin, line)) { if (line.empty()) continue; params = StringUtil::split_filter_empty(line, ' '); } if (finished) break; if (params.empty()) { task->unsubscribe(); task->punsubscribe(); break; } cmd = params[0]; params.erase(params.begin()); for (char &c : cmd) c = std::toupper(c); int ret; if (cmd == "SUBSCRIBE") ret = task->subscribe(params); else if (cmd == "UNSUBSCRIBE") ret = task->unsubscribe(params); else if (cmd == "PSUBSCRIBE") ret = task->psubscribe(params); else if (cmd == "PUNSUBSCRIBE") ret = task->punsubscribe(params); else if (cmd == "PING") { if (params.empty()) ret = task->ping(); else ret = task->ping(params[0]); } else if (cmd == "QUIT") ret = task->quit(); else { std::cerr << "Invalid command " << cmd << std::endl; ret = 0; } if (ret < 0) { std::cerr << "Send command failed " << strerror(errno) << std::endl; break; } } task->release(); wg.wait(); suber.deinit(); return 0; } workflow-0.11.8/tutorial/tutorial-19-dns_server.cc000066400000000000000000000045521476003635400221050ustar00rootroot00000000000000#include #include #include #include #include "workflow/WFDnsServer.h" #include "workflow/WFFacilities.h" 
void process(WFDnsTask *task) { protocol::DnsRequest *req = task->get_req(); protocol::DnsResponse *resp = task->get_resp(); std::string name = req->get_question_name(); int qtype = req->get_question_type(); int qclass = req->get_question_class(); int opcode = req->get_opcode(); printf("name:%s type:%s class:%s\n", name.c_str(), dns_type2str(qtype), dns_class2str(qclass)); if (opcode != 0) { resp->set_rcode(DNS_RCODE_NOT_IMPLEMENTED); return; } resp->set_rcode(DNS_RCODE_NO_ERROR); resp->set_aa(1); if (qtype == DNS_TYPE_A) { std::string cname = "cname.test"; resp->add_cname_record(DNS_ANSWER_SECTION, name.c_str(), DNS_CLASS_IN, 999, cname.c_str()); struct in_addr addr; inet_pton(AF_INET, "192.168.0.1", (void *)&addr); resp->add_a_record(DNS_ANSWER_SECTION, cname.c_str(), DNS_CLASS_IN, 600, &addr); inet_pton(AF_INET, "192.168.0.2", (void *)&addr); resp->add_a_record(DNS_ANSWER_SECTION, cname.c_str(), DNS_CLASS_IN, 600, &addr); } else if (qtype == DNS_TYPE_AAAA) { struct in6_addr addr; inet_pton(AF_INET6, "1234:5678:9abc:def0::", (void *)&addr); resp->add_aaaa_record(DNS_ANSWER_SECTION, name.c_str(), DNS_CLASS_IN, 600, &addr); } else if (qtype == DNS_TYPE_SOA) { const char *mname = "mname.test"; const char *rname = "rname.test"; resp->add_soa_record(DNS_ANSWER_SECTION, name.c_str(), DNS_CLASS_IN, 60, mname, rname, 123, 86400, 3600, 2592000, 7200); } else if (qtype == DNS_TYPE_TXT) { const char *raw_txt_data = "\x0dmy dns server\x0fyour dns server"; uint16_t data_len = 30; resp->add_raw_record(DNS_ANSWER_SECTION, name.c_str(), DNS_TYPE_TXT, DNS_CLASS_IN, 1200, raw_txt_data, data_len); } else { resp->set_rcode(DNS_RCODE_NOT_IMPLEMENTED); } } static WFFacilities::WaitGroup wait_group(1); void sig_handler(int signo) { wait_group.done(); } int main(int argc, char *argv[]) { unsigned short port; if (argc != 2) { fprintf(stderr, "USAGE: %s \n", argv[0]); exit(1); } signal(SIGINT, sig_handler); WFDnsServer server(process); port = atoi(argv[1]); if (server.start(port) == 0) { 
wait_group.wait(); server.stop(); } else { perror("Cannot start server"); } return 0; } workflow-0.11.8/tutorial/xmake.lua000066400000000000000000000025551476003635400171450ustar00rootroot00000000000000set_group("tutorial") set_default(false) if not is_plat("macosx") then add_ldflags("-lrt") end function all_examples() local res = {} for _, x in ipairs(os.files("*.cc")) do local item = {} local s = path.filename(x) if ((s == "upstream_unittest.cc" and not has_config("upstream")) or ((s == "tutorial-02-redis_cli.cc" or s == "tutorial-03-wget_to_redis.cc" or s == "tutorial-18-redis_subscriber.cc") and not has_config("redis")) or (s == "tutorial-12-mysql_cli.cc" and not has_config("mysql")) or (s == "tutorial-14-consul_cli.cc" and not has_config("consul")) or (s == "tutorial-13-kafka_cli.cc")) then else table.insert(item, s:sub(1, #s - 3)) -- target table.insert(item, path.relative(x, ".")) -- source table.insert(res, item) end end return res end for _, example in ipairs(all_examples()) do target(example[1]) set_kind("binary") add_files(example[2]) add_deps("workflow") end target("tutorial-13-kafka_cli") if has_config("kafka") then set_kind("binary") add_files("tutorial-13-kafka_cli.cc") add_packages("zlib", "snappy", "zstd", "lz4") add_deps("wfkafka") else set_kind("phony") end includes("tutorial-10-user_defined_protocol", "tutorial-16-graceful_restart") workflow-0.11.8/workflow-config.cmake.in000066400000000000000000000005441476003635400202120ustar00rootroot00000000000000@PACKAGE_INIT@ set(WORKFLOW_VERSION "@workflow_VERSION@") set_and_check(WORKFLOW_INCLUDE_DIR "@PACKAGE_CONFIG_INC_DIR@") set_and_check(WORKFLOW_LIB_DIR "@PACKAGE_CONFIG_LIB_DIR@") if (EXISTS "${CMAKE_CURRENT_LIST_DIR}/workflow-targets.cmake") include ("${CMAKE_CURRENT_LIST_DIR}/workflow-targets.cmake") endif () check_required_components(workflow) workflow-0.11.8/xmake.lua000066400000000000000000000027041476003635400152760ustar00rootroot00000000000000set_project("workflow") set_version("0.11.8") 
option("workflow_inc", {description = "workflow inc", default = "$(projectdir)/_include"}) option("workflow_lib", {description = "workflow lib", default = "$(projectdir)/_lib"}) option("kafka", {description = "build kafka component", default = false}) option("consul", {description = "build consul component", default = true}) option("mysql", {description = "build mysql component", default = true}) option("redis", {description = "build redis component", default = true}) option("upstream", {description = "build upstream component", default = true}) option("memcheck", {description = "valgrind memcheck", default = false}) if is_mode("release") then set_optimize("faster") set_strip("all") elseif is_mode("debug") then set_symbols("debug") set_optimize("none") end set_languages("gnu90", "c++11") set_warnings("all") set_exceptions("no-cxx") add_requires("openssl") add_packages("openssl") add_syslinks("pthread") if has_config("kafka") then add_requires("snappy", "lz4", "zstd", "zlib") end add_includedirs(get_config("workflow_inc")) add_includedirs(path.join(get_config("workflow_inc"), "workflow")) set_config("buildir", "build.xmake") add_cflags("-fPIC", "-pipe") add_cxxflags("-fPIC", "-pipe", "-Wno-invalid-offsetof") if (is_plat("macosx")) then add_cxxflags("-Wno-deprecated-declarations") end includes("src", "test", "benchmark", "tutorial")