aboutsummaryrefslogtreecommitdiff
path: root/spark-standalone-agent/src/main/java/me/lucko/spark/standalone/StandaloneSparkAgent.java
blob: 2820b85d1db0531a398a3b48c364d3d0ab2c7e2b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
/*
 * This file is part of spark.
 *
 *  Copyright (c) lucko (Luck) <luck@lucko.me>
 *  Copyright (c) contributors
 *
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

package me.lucko.spark.standalone;

import com.sun.tools.attach.VirtualMachine;
import com.sun.tools.attach.VirtualMachineDescriptor;

import java.lang.instrument.Instrumentation;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Entry points for running spark as a standalone Java agent.
 *
 * <p>The same jar can be used in three ways:
 * <ul>
 *   <li>run directly ({@link #main}) to attach itself to another running JVM,</li>
 *   <li>passed via {@code -javaagent} ({@link #premain}),</li>
 *   <li>loaded dynamically through the attach API ({@link #agentmain}).</li>
 * </ul>
 */
public class StandaloneSparkAgent {

    /**
     * Entry point when the agent jar is run as a normal jar.
     *
     * <p>With no arguments, prints usage to stderr and lists the JVM processes visible to the
     * attach API. Otherwise, attaches to the JVM whose id is {@code args[0]} and loads this jar
     * into it as an agent. The remaining arguments are joined with {@code ','} and forwarded as
     * the agent argument string.
     *
     * @param args {@code <pid> [args...]}
     */
    public static void main(String[] args) {
        if (args.length == 0) {
            System.err.println("Usage: java -jar spark-standalone-agent.jar <pid> [args...]");

            List<VirtualMachineDescriptor> vms = VirtualMachine.list();
            if (vms.isEmpty()) {
                return;
            }

            System.out.println("Current JVM processes:");
            for (VirtualMachineDescriptor vm : vms) {
                System.out.println("  pid=" + vm.id() + " (" + vm.displayName() + ")");
            }

            return;
        }

        try {
            String agentPath = getAgentPath();
            String arguments = String.join(",", Arrays.copyOfRange(args, 1, args.length));

            VirtualMachine vm = VirtualMachine.attach(args[0]);
            try {
                vm.loadAgent(agentPath, arguments);
                System.out.println("[spark] Agent loaded successfully.");
            } finally {
                // Release the attach connection even when loading the agent fails.
                vm.detach();
            }
        } catch (Throwable e) {
            System.err.println("Failed to attach agent to process " + args[0]);
            e.printStackTrace(System.err);
        }
    }

    /**
     * Entry point when the agent is loaded via {@code -javaagent}.
     *
     * @param agentArgs comma-separated agent options; may be {@code null}
     * @param instrumentation the instrumentation instance supplied by the JVM
     */
    public static void premain(String agentArgs, Instrumentation instrumentation) {
        System.out.println("[spark] Loading standalone agent... (premain)");
        init(agentArgs, instrumentation);
    }

    /**
     * Entry point when the agent is loaded via {@code VirtualMachine#loadAgent}.
     *
     * @param agentArgs comma-separated agent options; may be {@code null}
     * @param instrumentation the instrumentation instance supplied by the JVM
     */
    public static void agentmain(String agentArgs, Instrumentation instrumentation) {
        System.out.println("[spark] Loading standalone agent... (agentmain)");
        init(agentArgs, instrumentation);
    }

    /**
     * Parses the agent arguments and starts the plugin.
     *
     * <p>Any failure is reported to stderr rather than propagated, so it is not thrown back
     * into the agent's {@code premain}/{@code agentmain} entry point.
     */
    private static void init(String agentArgs, Instrumentation instrumentation) {
        try {
            Map<String, String> arguments = parseArguments(agentArgs);
            new StandaloneSparkPlugin(instrumentation, arguments);
        } catch (Throwable e) {
            System.err.println("[spark] Loading failed :(");
            e.printStackTrace(System.err);
        }
    }

    /**
     * Parses a comma-separated agent argument string into a map.
     *
     * <p>{@code key=value} tokens map {@code key} to everything after the first {@code '='}.
     * Bare {@code key} tokens map to {@code "true"}. Empty tokens (from a null or empty string,
     * consecutive commas, or a leading comma) are ignored.
     *
     * @param agentArgs the raw agent argument string; may be {@code null}
     * @return a mutable map of parsed options, never {@code null}
     */
    private static Map<String, String> parseArguments(String agentArgs) {
        Map<String, String> arguments = new HashMap<>();
        if (agentArgs == null || agentArgs.isEmpty()) {
            return arguments;
        }
        for (String arg : agentArgs.split(",")) {
            if (arg.isEmpty()) {
                continue;
            }
            int eq = arg.indexOf('=');
            if (eq >= 0) {
                arguments.put(arg.substring(0, eq), arg.substring(eq + 1));
            } else {
                arguments.put(arg, "true");
            }
        }
        return arguments;
    }

    /**
     * Resolves the filesystem path of the jar containing this class.
     *
     * <p>{@link URL#getPath()} returns the percent-encoded form (for example, spaces become
     * {@code %20}), which is not a usable filesystem path. Converting through a {@code URI} and
     * {@link Paths#get(java.net.URI)} decodes it properly.
     *
     * @return the decoded path of the agent jar
     * @throws URISyntaxException if the code source location is not a valid URI
     */
    private static String getAgentPath() throws URISyntaxException {
        URL location = StandaloneSparkAgent.class.getProtectionDomain().getCodeSource().getLocation();
        return Paths.get(location.toURI()).toString();
    }

}