掘金 后端 ( ) • 2024-06-24 10:27

一、背景

Hi,大家好,我是大鱼七成饱。最近对接其他部门,遇到个必须用java调用python方法的事情。起因是算法有个平台,训练完直接生成微服务,但是呢,生成的代码不能改(作为后端也表示理解😢)。这个服务很牛,牛的地方在于rpc调用不通,http只支持python调用(numpy解析返回值),http限制特定语言是什么鬼,作为老后端开发,想把那个平台开发这拉出来,xx五分钟(转成json返回就能解决的事),此刻心情见下图。看了看钱包余额,算了,混口饭吃,还是想想解决方案吧。

image-20240623161049774.png

二、解决方案调研

image-20240623161933595.png

为了服务器成本考虑(大环境都在广进),不想为了一个小方法搭建python服务,因此遍历baidu、谷歌和闭关思考后,找到了这三个方向的解决方案:

  • 支持在代码层面运行java的Jython或者底层支持python运行的Graalvm
  • 本机或者镜像安装python环境
  • 用java实现python方法的协议

1、代码层面运行java的Jython或者底层支持python运行的Graalvm

  • Jython

官方网站:https://www.jython.org/

目前支持到 Jython 2.7.x, python3还在开发中。 这款工具支持嵌入式的方式执行Java脚本。使用的话,先引入pom

  <dependencies>
        <dependency>
            <groupId>org.python</groupId>
            <artifactId>jython-standalone</artifactId>
            <!--指定Python的版本-->
            <version>2.7.0</version>
        </dependency>
    </dependencies>

然后调用

image-20240623165334758.png

这个组件对于一般的功能是可以的,这次的需求是调用numpy的函数,numpy底层是C代码,官方明确表示不支持,所以没有采纳这个方案

  • GraalVM
import org.graalvm.polyglot.Context;
import org.graalvm.polyglot.Source;
import org.graalvm.polyglot.Value;

public class Main {
    public static void main(String[] args) {
        Context context = Context.newBuilder("python").build();
        Source source = Source.newBuilder("python", "print('Hello, World!')", "hello.py").build();
        context.eval(source);

        Value result = context.getBindings("python").getMember("result");
        System.out.println("Result: " + result);
    }
}

image-20240623175521850.png

官方表示是支持numpy运行,但是版本太高,升级到22版本,需要适配现有代码,时间来不及

2、调用本机或者docker中的python命令

  • java自带的API
@Test
public void givenPythonScript_whenPythonProcessInvoked_thenSuccess() throws Exception {
    ProcessBuilder processBuilder = new ProcessBuilder("python", resolvePythonScriptPath("hello.py"));
    processBuilder.redirectErrorStream(true);

    Process process = processBuilder.start();
    List<String> results = readProcessOutput(process.getInputStream());

    assertThat("Results should not be empty", results, is(not(empty())));
    assertThat("Results should contain output of script: ", results, hasItem(
      containsString("Hello Baeldung Readers!!")));

    int exitCode = process.waitFor();
    assertEquals("No errors should be detected", 0, exitCode);
}
  • apache提供的调用工具

    <dependency>
        <groupId>org.apache.commons</groupId>
        <artifactId>commons-exec</artifactId>
        <version>1.3</version>
    </dependency>
    
    @Test
    public void givenPythonScript_whenPythonProcessExecuted_thenSuccess() 
      throws ExecuteException, IOException {
        String line = "python " + resolvePythonScriptPath("hello.py");
        CommandLine cmdLine = CommandLine.parse(line);
            
        ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
        PumpStreamHandler streamHandler = new PumpStreamHandler(outputStream);
            
        DefaultExecutor executor = new DefaultExecutor();
        executor.setStreamHandler(streamHandler);
    
        int exitCode = executor.execute(cmdLine);
        assertEquals("No errors should be detected", 0, exitCode);
        assertEquals("Should contain script output: ", "Hello Baeldung Readers!!", outputStream.toString()
          .trim());
    }
    

    这个方案属于比较简单的解决方案,适用性也较好。但是在我们这个场景不适合,需要修改基础镜像,当时找了7、8个人才找到负责人,还不让改。。。。

image-20240623175009147.png

3、用java实现python的协议

研究了一下python方法(<https://github.com/triton-inference-server/python_backend/blob/main/src/resources/triton_python_backend_utils.py),代码如下:>

    import numpy as np
    import struct
    import pprint
    def deserialize_bytes_tensor(encoded_tensor):
        """
        Deserializes an encoded bytes tensor into an
        numpy array of dtype of python objects
        Parameters
        ----------
        encoded_tensor : bytes
            The encoded bytes tensor where each element
            has its length in first 4 bytes followed by
            the content
        Returns
        -------
        string_tensor : np.array
            The 1-D numpy array of type object containing the
            deserialized bytes in 'C' order.
        """
        strs = list()
        offset = 0
        val_buf = encoded_tensor
        while offset < len(val_buf):
            l = struct.unpack_from("<I", val_buf, offset)[0]
            offset += 4
            sb = struct.unpack_from("<{}s".format(l), val_buf, offset)[0]
            offset += l
            strs.append(sb)
        return (np.array(strs, dtype=np.object_))

研究了下,内容超前面有4个长度的offset,跳过就可以。因此java实现的时候用正则切割了下。这里就不演示了,不加解释的话,估计只能自己看懂。这个方案的问题是只能在这个特定场景使用,后面扩展其他方法,就要重新研究。老话说拔出浓来才是好膏药,考虑时间成本和后场景的不确定,最后采用了这个方案。

image-20240623174916880.png

三、总结

java调用python属于有一堆解决方案的问题,本文也探讨了适用的场景。如果可以选,从源头解决是最好的,但是技术方案还要考虑现实情况,成本和交付时间等,综合考虑合适的才是最好的。

image-20240623175237644.png

本人公众号大鱼七成饱,历史文章会在上面同步