diff options
Diffstat (limited to 'spark-common/src/main/java/me/lucko')
38 files changed, 1931 insertions, 388 deletions
| diff --git a/spark-common/src/main/java/me/lucko/spark/common/SparkPlatform.java b/spark-common/src/main/java/me/lucko/spark/common/SparkPlatform.java index dae04ff..84f435a 100644 --- a/spark-common/src/main/java/me/lucko/spark/common/SparkPlatform.java +++ b/spark-common/src/main/java/me/lucko/spark/common/SparkPlatform.java @@ -23,6 +23,8 @@ package me.lucko.spark.common;  import com.google.common.collect.ImmutableList;  import com.google.common.collect.ImmutableMap; +import me.lucko.bytesocks.client.BytesocksClient; +import me.lucko.bytesocks.client.BytesocksClientFactory;  import me.lucko.spark.common.activitylog.ActivityLog;  import me.lucko.spark.common.api.SparkApi;  import me.lucko.spark.common.command.Arguments; @@ -43,6 +45,7 @@ import me.lucko.spark.common.monitor.memory.GarbageCollectorStatistics;  import me.lucko.spark.common.monitor.net.NetworkMonitor;  import me.lucko.spark.common.monitor.ping.PingStatistics;  import me.lucko.spark.common.monitor.ping.PlayerPingProvider; +import me.lucko.spark.common.monitor.tick.SparkTickStatistics;  import me.lucko.spark.common.monitor.tick.TickStatistics;  import me.lucko.spark.common.platform.PlatformStatisticsProvider;  import me.lucko.spark.common.sampler.BackgroundSamplerManager; @@ -53,6 +56,7 @@ import me.lucko.spark.common.tick.TickReporter;  import me.lucko.spark.common.util.BytebinClient;  import me.lucko.spark.common.util.Configuration;  import me.lucko.spark.common.util.TemporaryFiles; +import me.lucko.spark.common.ws.TrustedKeyStore;  import net.kyori.adventure.text.Component;  import net.kyori.adventure.text.event.ClickEvent; @@ -95,6 +99,8 @@ public class SparkPlatform {      private final Configuration configuration;      private final String viewerUrl;      private final BytebinClient bytebinClient; +    private final BytesocksClient bytesocksClient; +    private final TrustedKeyStore trustedKeyStore;      private final boolean disableResponseBroadcast;      private final List<CommandModule> commandModules;      private final List<Command> commands; @@ -118,8 +124,12 @@ public class SparkPlatform {          this.configuration = new Configuration(this.plugin.getPluginDirectory().resolve("config.json"));          this.viewerUrl = this.configuration.getString("viewerUrl", "https://spark.lucko.me/"); -        String bytebinUrl = this.configuration.getString("bytebinUrl", "https://bytebin.lucko.me/"); +        String bytebinUrl = this.configuration.getString("bytebinUrl", "https://spark-usercontent.lucko.me/"); +        String bytesocksHost = this.configuration.getString("bytesocksHost", "spark-usersockets.lucko.me"); +          this.bytebinClient = new BytebinClient(bytebinUrl, "spark-plugin"); +        this.bytesocksClient = BytesocksClientFactory.newClient(bytesocksHost, "spark-plugin"); +        this.trustedKeyStore = new TrustedKeyStore(this.configuration);          this.disableResponseBroadcast = this.configuration.getBoolean("disableResponseBroadcast", false); @@ -144,9 +154,13 @@ public class SparkPlatform {          this.samplerContainer = new SamplerContainer();          this.backgroundSamplerManager = new BackgroundSamplerManager(this, this.configuration); +        TickStatistics tickStatistics = plugin.createTickStatistics();          this.tickHook = plugin.createTickHook();          this.tickReporter = plugin.createTickReporter(); -        this.tickStatistics = this.tickHook != null || this.tickReporter != null ? new TickStatistics() : null; +        if (tickStatistics == null && (this.tickHook != null || this.tickReporter != null)) { +            tickStatistics = new SparkTickStatistics(); +        } +        this.tickStatistics = tickStatistics;          PlayerPingProvider pingProvider = plugin.createPlayerPingProvider();          this.pingStatistics = pingProvider != null ? new PingStatistics(pingProvider) : null; @@ -159,12 +173,12 @@ public class SparkPlatform {              throw new RuntimeException("Platform has already been enabled!");          } -        if (this.tickHook != null) { -            this.tickHook.addCallback(this.tickStatistics); +        if (this.tickHook != null && this.tickStatistics instanceof SparkTickStatistics) { +            this.tickHook.addCallback((TickHook.Callback) this.tickStatistics);              this.tickHook.start();          } -        if (this.tickReporter != null) { -            this.tickReporter.addCallback(this.tickStatistics); +        if (this.tickReporter != null&& this.tickStatistics instanceof SparkTickStatistics) { +            this.tickReporter.addCallback((TickReporter.Callback) this.tickStatistics);              this.tickReporter.start();          }          if (this.pingStatistics != null) { @@ -228,6 +242,14 @@ public class SparkPlatform {          return this.bytebinClient;      } +    public BytesocksClient getBytesocksClient() { +        return this.bytesocksClient; +    } + +    public TrustedKeyStore getTrustedKeyStore() { +        return this.trustedKeyStore; +    } +      public boolean shouldBroadcastResponse() {          return !this.disableResponseBroadcast;      } diff --git a/spark-common/src/main/java/me/lucko/spark/common/SparkPlugin.java b/spark-common/src/main/java/me/lucko/spark/common/SparkPlugin.java index b7aef2a..a3bdceb 100644 --- a/spark-common/src/main/java/me/lucko/spark/common/SparkPlugin.java +++ b/spark-common/src/main/java/me/lucko/spark/common/SparkPlugin.java @@ -23,6 +23,7 @@ package me.lucko.spark.common;  import me.lucko.spark.api.Spark;  import me.lucko.spark.common.command.sender.CommandSender;  import me.lucko.spark.common.monitor.ping.PlayerPingProvider; +import me.lucko.spark.common.monitor.tick.TickStatistics;  import me.lucko.spark.common.platform.MetadataProvider;  import me.lucko.spark.common.platform.PlatformInfo;  import me.lucko.spark.common.platform.serverconfig.ServerConfigProvider; @@ -128,6 +129,18 @@ public interface SparkPlugin {      }      /** +     * Creates tick statistics for the platform, if supported. +     * +     * <p>Spark is able to provide a default implementation for platforms that +     * provide a {@link TickHook} and {@link TickReporter}.</p> +     * +     * @return a new tick statistics instance +     */ +    default TickStatistics createTickStatistics() { +        return null; +    } + +    /**       * Creates a class source lookup function.       *       * @return the class source lookup function diff --git a/spark-common/src/main/java/me/lucko/spark/common/api/SparkApi.java b/spark-common/src/main/java/me/lucko/spark/common/api/SparkApi.java index 5b1ec2b..9e4eee4 100644 --- a/spark-common/src/main/java/me/lucko/spark/common/api/SparkApi.java +++ b/spark-common/src/main/java/me/lucko/spark/common/api/SparkApi.java @@ -151,6 +151,8 @@ public class SparkApi implements Spark {                          return stats.duration10Sec();                      case MINUTES_1:                          return stats.duration1Min(); +                    case MINUTES_5: +                        return stats.duration5Min();                      default:                          throw new AssertionError(window);                  } diff --git a/spark-common/src/main/java/me/lucko/spark/common/command/modules/HeapAnalysisModule.java b/spark-common/src/main/java/me/lucko/spark/common/command/modules/HeapAnalysisModule.java index 5bd62a8..6ac3b2f 100644 --- a/spark-common/src/main/java/me/lucko/spark/common/command/modules/HeapAnalysisModule.java +++ b/spark-common/src/main/java/me/lucko/spark/common/command/modules/HeapAnalysisModule.java @@ -32,6 +32,7 @@ import me.lucko.spark.common.heapdump.HeapDump;  import me.lucko.spark.common.heapdump.HeapDumpSummary;  import me.lucko.spark.common.util.Compression;  import me.lucko.spark.common.util.FormatUtil; +import me.lucko.spark.common.util.MediaTypes;  import me.lucko.spark.proto.SparkHeapProtos;  import net.kyori.adventure.text.event.ClickEvent; @@ -52,7 +53,6 @@ import static net.kyori.adventure.text.format.NamedTextColor.GREEN;  import static net.kyori.adventure.text.format.NamedTextColor.RED;  public class HeapAnalysisModule implements CommandModule { -    private static final String SPARK_HEAP_MEDIA_TYPE = "application/x-spark-heap";      @Override      public void registerCommands(Consumer<Command> consumer) { @@ -97,7 +97,7 @@ public class HeapAnalysisModule implements CommandModule {              saveToFile = true;          } else {              try { -                String key = platform.getBytebinClient().postContent(output, SPARK_HEAP_MEDIA_TYPE).key(); +                String key = platform.getBytebinClient().postContent(output, MediaTypes.SPARK_HEAP_MEDIA_TYPE).key();                  String url = platform.getViewerUrl() + key;                  resp.broadcastPrefixed(text("Heap dump summmary output:", GOLD)); diff --git a/spark-common/src/main/java/me/lucko/spark/common/command/modules/SamplerModule.java b/spark-common/src/main/java/me/lucko/spark/common/command/modules/SamplerModule.java index cd00f0d..27e790f 100644 --- a/spark-common/src/main/java/me/lucko/spark/common/command/modules/SamplerModule.java +++ b/spark-common/src/main/java/me/lucko/spark/common/command/modules/SamplerModule.java @@ -22,6 +22,7 @@ package me.lucko.spark.common.command.modules;  import com.google.common.collect.Iterables; +import me.lucko.bytesocks.client.BytesocksClient;  import me.lucko.spark.common.SparkPlatform;  import me.lucko.spark.common.activitylog.Activity;  import me.lucko.spark.common.command.Arguments; @@ -33,6 +34,7 @@ import me.lucko.spark.common.command.tabcomplete.CompletionSupplier;  import me.lucko.spark.common.command.tabcomplete.TabCompleter;  import me.lucko.spark.common.sampler.Sampler;  import me.lucko.spark.common.sampler.SamplerBuilder; +import me.lucko.spark.common.sampler.SamplerMode;  import me.lucko.spark.common.sampler.ThreadDumper;  import me.lucko.spark.common.sampler.ThreadGrouper;  import me.lucko.spark.common.sampler.async.AsyncSampler; @@ -40,7 +42,9 @@ import me.lucko.spark.common.sampler.node.MergeMode;  import me.lucko.spark.common.sampler.source.ClassSourceLookup;  import me.lucko.spark.common.tick.TickHook;  import me.lucko.spark.common.util.FormatUtil; +import me.lucko.spark.common.util.MediaTypes;  import me.lucko.spark.common.util.MethodDisambiguator; +import me.lucko.spark.common.ws.ViewerSocket;  import me.lucko.spark.proto.SparkSamplerProtos;  import net.kyori.adventure.text.Component; @@ -67,7 +71,6 @@ import static net.kyori.adventure.text.format.NamedTextColor.RED;  import static net.kyori.adventure.text.format.NamedTextColor.WHITE;  public class SamplerModule implements CommandModule { -    private static final String SPARK_SAMPLER_MEDIA_TYPE = "application/x-spark-sampler";      @Override      public void registerCommands(Consumer<Command> consumer) { @@ -75,11 +78,13 @@ public class SamplerModule implements CommandModule {                  .aliases("profiler", "sampler")                  .allowSubCommand(true)                  .argumentUsage("info", "", null) +                .argumentUsage("open", "", null)                  .argumentUsage("start", "timeout", "timeout seconds")                  .argumentUsage("start", "thread *", null)                  .argumentUsage("start", "thread", "thread name")                  .argumentUsage("start", "only-ticks-over", "tick length millis")                  .argumentUsage("start", "interval", "interval millis") +                .argumentUsage("start", "alloc", null)                  .argumentUsage("stop", "", null)                  .argumentUsage("cancel", "", null)                  .executor(this::profiler) @@ -94,14 +99,14 @@ public class SamplerModule implements CommandModule {                          }                          if (subCommand.equals("start")) {                              opts = new ArrayList<>(Arrays.asList("--timeout", "--regex", "--combine-all", -                                    "--not-combined", "--interval", "--only-ticks-over", "--force-java-sampler")); +                                    "--not-combined", "--interval", "--only-ticks-over", "--force-java-sampler", "--alloc", "--alloc-live-only"));                              opts.removeAll(arguments);                              opts.add("--thread"); // allowed multiple times                          }                      }                      return TabCompleter.create() -                            .at(0, CompletionSupplier.startsWith(Arrays.asList("info", "start", "stop", "cancel"))) +                            .at(0, CompletionSupplier.startsWith(Arrays.asList("info", "start", "open", "stop", "cancel")))                              .from(1, CompletionSupplier.startsWith(opts))                              .complete(arguments);                  }) @@ -117,6 +122,16 @@ public class SamplerModule implements CommandModule {              return;          } +        if (subCommand.equals("open") || arguments.boolFlag("open")) { +            profilerOpen(platform, sender, resp, arguments); +            return; +        } + +        if (subCommand.equals("trust-viewer") || arguments.boolFlag("trust-viewer")) { +            profilerTrustViewer(platform, sender, resp, arguments); +            return; +        } +          if (subCommand.equals("cancel") || arguments.boolFlag("cancel")) {              profilerCancel(platform, resp);              return; @@ -166,9 +181,12 @@ public class SamplerModule implements CommandModule {                      "Consider setting a timeout value over 30 seconds."));          } -        double intervalMillis = arguments.doubleFlag("interval"); -        if (intervalMillis <= 0) { -            intervalMillis = 4; +        SamplerMode mode = arguments.boolFlag("alloc") ? SamplerMode.ALLOCATION : SamplerMode.EXECUTION; +        boolean allocLiveOnly = arguments.boolFlag("alloc-live-only"); + +        double interval = arguments.doubleFlag("interval"); +        if (interval <= 0) { +            interval = mode.defaultInterval();          }          boolean ignoreSleeping = arguments.boolFlag("ignore-sleeping"); @@ -213,23 +231,33 @@ public class SamplerModule implements CommandModule {          resp.broadcastPrefixed(text("Starting a new profiler, please wait..."));          SamplerBuilder builder = new SamplerBuilder(); +        builder.mode(mode);          builder.threadDumper(threadDumper);          builder.threadGrouper(threadGrouper);          if (timeoutSeconds != -1) {              builder.completeAfter(timeoutSeconds, TimeUnit.SECONDS);          } -        builder.samplingInterval(intervalMillis); +        builder.samplingInterval(interval);          builder.ignoreSleeping(ignoreSleeping);          builder.ignoreNative(ignoreNative);          builder.forceJavaSampler(forceJavaSampler); +        builder.allocLiveOnly(allocLiveOnly);          if (ticksOver != -1) {              builder.ticksOver(ticksOver, tickHook);          } -        Sampler sampler = builder.start(platform); + +        Sampler sampler; +        try { +            sampler = builder.start(platform); +        } catch (UnsupportedOperationException e) { +            resp.replyPrefixed(text(e.getMessage(), RED)); +            return; +        } +          platform.getSamplerContainer().setActiveSampler(sampler);          resp.broadcastPrefixed(text() -                .append(text("Profiler is now running!", GOLD)) +                .append(text((mode == SamplerMode.ALLOCATION ? "Allocation Profiler" : "Profiler") + " is now running!", GOLD))                  .append(space())                  .append(text("(" + (sampler instanceof AsyncSampler ? "async" : "built-in java") + ")", DARK_GRAY))                  .build() @@ -239,6 +267,8 @@ public class SamplerModule implements CommandModule {              resp.broadcastPrefixed(text("It will run in the background until it is stopped by an admin."));              resp.broadcastPrefixed(text("To stop the profiler and upload the results, run:"));              resp.broadcastPrefixed(cmdPrompt("/" + platform.getPlugin().getCommandName() + " profiler stop")); +            resp.broadcastPrefixed(text("To view the profiler while it's running, run:")); +            resp.broadcastPrefixed(cmdPrompt("/" + platform.getPlugin().getCommandName() + " profiler open"));          } else {              resp.broadcastPrefixed(text("The results will be automatically returned after the profiler has been running for " + FormatUtil.formatSeconds(timeoutSeconds) + "."));          } @@ -258,13 +288,11 @@ public class SamplerModule implements CommandModule {          // await the result          if (timeoutSeconds != -1) { -            String comment = Iterables.getFirst(arguments.stringFlag("comment"), null); -            MethodDisambiguator methodDisambiguator = new MethodDisambiguator(); -            MergeMode mergeMode = arguments.boolFlag("separate-parent-calls") ? MergeMode.separateParentCalls(methodDisambiguator) : MergeMode.sameMethod(methodDisambiguator); +            Sampler.ExportProps exportProps = getExportProps(platform, resp, arguments);              boolean saveToFile = arguments.boolFlag("save-to-file");              future.thenAcceptAsync(s -> {                  resp.broadcastPrefixed(text("The active profiler has completed! Uploading results...")); -                handleUpload(platform, resp, s, comment, mergeMode, saveToFile); +                handleUpload(platform, resp, s, exportProps, saveToFile);              });          }      } @@ -291,6 +319,9 @@ public class SamplerModule implements CommandModule {                  resp.replyPrefixed(text("So far, it has profiled for " + FormatUtil.formatSeconds(runningTime) + "."));              } +            resp.replyPrefixed(text("To view the profiler while it's running, run:")); +            resp.replyPrefixed(cmdPrompt("/" + platform.getPlugin().getCommandName() + " profiler open")); +              long timeout = sampler.getAutoEndTime();              if (timeout == -1) {                  resp.replyPrefixed(text("To stop the profiler and upload the results, run:")); @@ -305,6 +336,48 @@ public class SamplerModule implements CommandModule {          }      } +    private void profilerOpen(SparkPlatform platform, CommandSender sender, CommandResponseHandler resp, Arguments arguments) { +        BytesocksClient bytesocksClient = platform.getBytesocksClient(); +        if (bytesocksClient == null) { +            resp.replyPrefixed(text("The live viewer is only supported on Java 11 or newer.", RED)); +            return; +        } + +        Sampler sampler = platform.getSamplerContainer().getActiveSampler(); +        if (sampler == null) { +            resp.replyPrefixed(text("The profiler isn't running!")); +            resp.replyPrefixed(text("To start a new one, run:")); +            resp.replyPrefixed(cmdPrompt("/" + platform.getPlugin().getCommandName() + " profiler start")); +            return; +        } + +        Sampler.ExportProps exportProps = getExportProps(platform, resp, arguments); +        handleOpen(platform, bytesocksClient, resp, sampler, exportProps); +    } + +    private void profilerTrustViewer(SparkPlatform platform, CommandSender sender, CommandResponseHandler resp, Arguments arguments) { +        Set<String> ids = arguments.stringFlag("id"); +        if (ids.isEmpty()) { +            resp.replyPrefixed(text("Please provide a client id with '--id <client id>'.")); +            return; +        } + +        for (String id : ids) { +            boolean success = platform.getTrustedKeyStore().trustPendingKey(id); +            if (success) { +                Sampler sampler = platform.getSamplerContainer().getActiveSampler(); +                if (sampler != null) { +                    for (ViewerSocket socket : sampler.getAttachedSockets()) { +                        socket.sendClientTrustedMessage(id); +                    } +                } +                resp.replyPrefixed(text("Client connected to the viewer using id '" + id + "' is now trusted.")); +            } else { +                resp.replyPrefixed(text("Unable to find pending client with id '" + id + "'.")); +            } +        } +    } +      private void profilerCancel(SparkPlatform platform, CommandResponseHandler resp) {          Sampler sampler = platform.getSamplerContainer().getActiveSampler();          if (sampler == null) { @@ -331,10 +404,8 @@ public class SamplerModule implements CommandModule {                  resp.broadcastPrefixed(text("Stopping the profiler & uploading results, please wait..."));              } -            String comment = Iterables.getFirst(arguments.stringFlag("comment"), null); -            MethodDisambiguator methodDisambiguator = new MethodDisambiguator(); -            MergeMode mergeMode = arguments.boolFlag("separate-parent-calls") ? MergeMode.separateParentCalls(methodDisambiguator) : MergeMode.sameMethod(methodDisambiguator); -            handleUpload(platform, resp, sampler, comment, mergeMode, saveToFile); +            Sampler.ExportProps exportProps = getExportProps(platform, resp, arguments); +            handleUpload(platform, resp, sampler, exportProps, saveToFile);              // if the previous sampler was running in the background, create a new one              if (platform.getBackgroundSamplerManager().restartBackgroundSampler()) { @@ -347,15 +418,15 @@ public class SamplerModule implements CommandModule {          }      } -    private void handleUpload(SparkPlatform platform, CommandResponseHandler resp, Sampler sampler, String comment, MergeMode mergeMode, boolean saveToFileFlag) { -        SparkSamplerProtos.SamplerData output = sampler.toProto(platform, resp.sender(), comment, mergeMode, ClassSourceLookup.create(platform)); +    private void handleUpload(SparkPlatform platform, CommandResponseHandler resp, Sampler sampler, Sampler.ExportProps exportProps, boolean saveToFileFlag) { +        SparkSamplerProtos.SamplerData output = sampler.toProto(platform, exportProps);          boolean saveToFile = false;          if (saveToFileFlag) {              saveToFile = true;          } else {              try { -                String key = platform.getBytebinClient().postContent(output, SPARK_SAMPLER_MEDIA_TYPE).key(); +                String key = platform.getBytebinClient().postContent(output, MediaTypes.SPARK_SAMPLER_MEDIA_TYPE).key();                  String url = platform.getViewerUrl() + key;                  resp.broadcastPrefixed(text("Profiler stopped & upload complete!", GOLD)); @@ -391,6 +462,45 @@ public class SamplerModule implements CommandModule {          }      } +    private void handleOpen(SparkPlatform platform, BytesocksClient bytesocksClient, CommandResponseHandler resp, Sampler sampler, Sampler.ExportProps exportProps) { +        try { +            ViewerSocket socket = new ViewerSocket(platform, bytesocksClient, exportProps); +            sampler.attachSocket(socket); +            exportProps.channelInfo(socket.getPayload()); + +            SparkSamplerProtos.SamplerData data = sampler.toProto(platform, exportProps); + +            String key = platform.getBytebinClient().postContent(data, MediaTypes.SPARK_SAMPLER_MEDIA_TYPE, "live").key(); +            String url = platform.getViewerUrl() + key; + +            resp.broadcastPrefixed(text("Profiler live viewer:", GOLD)); +            resp.broadcast(text() +                    .content(url) +                    .color(GRAY) +                    .clickEvent(ClickEvent.openUrl(url)) +                    .build() +            ); + +            platform.getActivityLog().addToLog(Activity.urlActivity(resp.sender(), System.currentTimeMillis(), "Profiler (live)", url)); +        } catch (Exception e) { +            resp.replyPrefixed(text("An error occurred whilst opening the live profiler.", RED)); +            e.printStackTrace(); +        } +    } + +    private Sampler.ExportProps getExportProps(SparkPlatform platform, CommandResponseHandler resp, Arguments arguments) { +        return new Sampler.ExportProps() +                .creator(resp.sender().toData()) +                .comment(Iterables.getFirst(arguments.stringFlag("comment"), null)) +                .mergeMode(() -> { +                    MethodDisambiguator methodDisambiguator = new MethodDisambiguator(); +                    return arguments.boolFlag("separate-parent-calls") +                            ? MergeMode.separateParentCalls(methodDisambiguator) +                            : MergeMode.sameMethod(methodDisambiguator); +                }) +                .classSourceLookup(() -> ClassSourceLookup.create(platform)); +    } +      private static Component cmdPrompt(String cmd) {          return text()                  .append(text("  ")) diff --git a/spark-common/src/main/java/me/lucko/spark/common/heapdump/HeapDumpSummary.java b/spark-common/src/main/java/me/lucko/spark/common/heapdump/HeapDumpSummary.java index c0980e7..eaedd31 100644 --- a/spark-common/src/main/java/me/lucko/spark/common/heapdump/HeapDumpSummary.java +++ b/spark-common/src/main/java/me/lucko/spark/common/heapdump/HeapDumpSummary.java @@ -130,7 +130,7 @@ public final class HeapDumpSummary {                  .setPlatformMetadata(platform.getPlugin().getPlatformInfo().toData().toProto())                  .setCreator(creator.toData().toProto());          try { -            metadata.setPlatformStatistics(platform.getStatisticsProvider().getPlatformStatistics(null)); +            metadata.setPlatformStatistics(platform.getStatisticsProvider().getPlatformStatistics(null, true));          } catch (Exception e) {              e.printStackTrace();          } diff --git a/spark-common/src/main/java/me/lucko/spark/common/monitor/MonitoringExecutor.java b/spark-common/src/main/java/me/lucko/spark/common/monitor/MonitoringExecutor.java index 635ae20..cbacebf 100644 --- a/spark-common/src/main/java/me/lucko/spark/common/monitor/MonitoringExecutor.java +++ b/spark-common/src/main/java/me/lucko/spark/common/monitor/MonitoringExecutor.java @@ -20,6 +20,8 @@  package me.lucko.spark.common.monitor; +import me.lucko.spark.common.util.SparkThreadFactory; +  import java.util.concurrent.Executors;  import java.util.concurrent.ScheduledExecutorService; @@ -29,7 +31,8 @@ public enum MonitoringExecutor {      /** The executor used to monitor & calculate rolling averages. */      public static final ScheduledExecutorService INSTANCE = Executors.newSingleThreadScheduledExecutor(r -> {          Thread thread = Executors.defaultThreadFactory().newThread(r); -        thread.setName("spark-monitor"); +        thread.setName("spark-monitoring-thread"); +        thread.setUncaughtExceptionHandler(SparkThreadFactory.EXCEPTION_HANDLER);          thread.setDaemon(true);          return thread;      }); diff --git a/spark-common/src/main/java/me/lucko/spark/common/monitor/tick/SparkTickStatistics.java b/spark-common/src/main/java/me/lucko/spark/common/monitor/tick/SparkTickStatistics.java new file mode 100644 index 0000000..5877cbe --- /dev/null +++ b/spark-common/src/main/java/me/lucko/spark/common/monitor/tick/SparkTickStatistics.java @@ -0,0 +1,197 @@ +/* + * This file is part of spark. + * + *  Copyright (c) lucko (Luck) <luck@lucko.me> + *  Copyright (c) contributors + * + *  This program is free software: you can redistribute it and/or modify + *  it under the terms of the GNU General Public License as published by + *  the Free Software Foundation, either version 3 of the License, or + *  (at your option) any later version. + * + *  This program is distributed in the hope that it will be useful, + *  but WITHOUT ANY WARRANTY; without even the implied warranty of + *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the + *  GNU General Public License for more details. + * + *  You should have received a copy of the GNU General Public License + *  along with this program.  If not, see <http://www.gnu.org/licenses/>. + */ + +package me.lucko.spark.common.monitor.tick; + +import me.lucko.spark.api.statistic.misc.DoubleAverageInfo; +import me.lucko.spark.common.tick.TickHook; +import me.lucko.spark.common.tick.TickReporter; +import me.lucko.spark.common.util.RollingAverage; + +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.util.concurrent.TimeUnit; + +/** + * Calculates the servers TPS (ticks per second) rate. + * + * <p>The code use to calculate the TPS is the same as the code used by the Minecraft server itself. + * This means that this class will output values the same as the /tps command.</p> + * + * <p>We calculate our own values instead of pulling them from the server for two reasons. Firstly, + * it's easier - pulling from the server requires reflection code on each of the platforms, we'd + * rather avoid that. Secondly, it allows us to generate rolling averages over a shorter period of + * time.</p> + */ +public class SparkTickStatistics implements TickHook.Callback, TickReporter.Callback, TickStatistics { + +    private static final long SEC_IN_NANO = TimeUnit.SECONDS.toNanos(1); +    private static final int TPS = 20; +    private static final int TPS_SAMPLE_INTERVAL = 20; +    private static final BigDecimal TPS_BASE = new BigDecimal(SEC_IN_NANO).multiply(new BigDecimal(TPS_SAMPLE_INTERVAL)); + +    private final TpsRollingAverage tps5Sec = new TpsRollingAverage(5); +    private final TpsRollingAverage tps10Sec = new TpsRollingAverage(10); +    private final TpsRollingAverage tps1Min = new TpsRollingAverage(60); +    private final TpsRollingAverage tps5Min = new TpsRollingAverage(60 * 5); +    private final TpsRollingAverage tps15Min = new TpsRollingAverage(60 * 15); +    private final TpsRollingAverage[] tpsAverages = {this.tps5Sec, this.tps10Sec, this.tps1Min, this.tps5Min, this.tps15Min}; + +    private boolean durationSupported = false; +    private final RollingAverage tickDuration10Sec = new RollingAverage(TPS * 10); +    private final RollingAverage tickDuration1Min = new RollingAverage(TPS * 60); +    private final RollingAverage tickDuration5Min = new RollingAverage(TPS * 60 * 5); +    private final RollingAverage[] tickDurationAverages = {this.tickDuration10Sec, this.tickDuration1Min, this.tickDuration5Min}; + +    private long last = 0; + +    @Override +    public boolean isDurationSupported() { +        return this.durationSupported; +    } + +    @Override +    public void onTick(int currentTick) { +        if (currentTick % TPS_SAMPLE_INTERVAL != 0) { +            return; +        } + +        long now = System.nanoTime(); + +        if (this.last == 0) { +            this.last = now; +            return; +        } + +        long diff = now - this.last; +        BigDecimal currentTps = TPS_BASE.divide(new BigDecimal(diff), 30, RoundingMode.HALF_UP); +        BigDecimal total = currentTps.multiply(new BigDecimal(diff)); + +        for (TpsRollingAverage rollingAverage : this.tpsAverages) { +            rollingAverage.add(currentTps, diff, total); +        } + +        this.last = now; +    } + +    @Override +    public void onTick(double duration) { +        this.durationSupported = true; +        BigDecimal decimal = new BigDecimal(duration); +        for (RollingAverage rollingAverage : this.tickDurationAverages) { +            rollingAverage.add(decimal); +        } +    } + +    @Override +    public double tps5Sec() { +        return this.tps5Sec.getAverage(); +    } + +    @Override +    public double tps10Sec() { +        return this.tps10Sec.getAverage(); +    } + +    @Override +    public double tps1Min() { +        return this.tps1Min.getAverage(); +    } + +    @Override +    public double tps5Min() { +        return this.tps5Min.getAverage(); +    } + +    @Override +    public double tps15Min() { +        return this.tps15Min.getAverage(); +    } + +    @Override +    public DoubleAverageInfo duration10Sec() { +        if (!this.durationSupported) { +            return null; +        } +        return this.tickDuration10Sec; +    } + +    @Override +    public DoubleAverageInfo duration1Min() { +        if (!this.durationSupported) { +            return null; +        } +        return this.tickDuration1Min; +    } + +    @Override +    public DoubleAverageInfo duration5Min() { +        if (!this.durationSupported) { +            return null; +        } +        return this.tickDuration5Min; +    } + + +    /** +     * Rolling average calculator. +     * +     * <p>This code is taken from PaperMC/Paper, licensed under MIT.</p> +     * +     * @author aikar (PaperMC) https://github.com/PaperMC/Paper/blob/master/Spigot-Server-Patches/0021-Further-improve-server-tick-loop.patch +     */ +    public static final class TpsRollingAverage { +        private final int size; +        private long time; +        private BigDecimal total; +        private int index = 0; +        private final BigDecimal[] samples; +        private final long[] times; + +        TpsRollingAverage(int size) { +            this.size = size; +            this.time = size * SEC_IN_NANO; +            this.total = new BigDecimal(TPS).multiply(new BigDecimal(SEC_IN_NANO)).multiply(new BigDecimal(size)); +            this.samples = new BigDecimal[size]; +            this.times = new long[size]; +            for (int i = 0; i < size; i++) { +                this.samples[i] = new BigDecimal(TPS); +                this.times[i] = SEC_IN_NANO; +            } +        } + +        public void add(BigDecimal x, long t, BigDecimal total) { +            this.time -= this.times[this.index]; +            this.total = this.total.subtract(this.samples[this.index].multiply(new BigDecimal(this.times[this.index]))); +            this.samples[this.index] = x; +            this.times[this.index] = t; +            this.time += t; +            this.total = this.total.add(total); +            if (++this.index == this.size) { +                this.index = 0; +            } +        } + +        public double getAverage() { +            return this.total.divide(new BigDecimal(this.time), 30, RoundingMode.HALF_UP).doubleValue(); +        } +    } + +} diff --git a/spark-common/src/main/java/me/lucko/spark/common/monitor/tick/TickStatistics.java b/spark-common/src/main/java/me/lucko/spark/common/monitor/tick/TickStatistics.java index bd2b834..a48b178 100644 --- a/spark-common/src/main/java/me/lucko/spark/common/monitor/tick/TickStatistics.java +++ b/spark-common/src/main/java/me/lucko/spark/common/monitor/tick/TickStatistics.java @@ -20,168 +20,23 @@  package me.lucko.spark.common.monitor.tick; -import me.lucko.spark.common.tick.TickHook; -import me.lucko.spark.common.tick.TickReporter; -import me.lucko.spark.common.util.RollingAverage; - -import java.math.BigDecimal; -import java.math.RoundingMode; -import java.util.concurrent.TimeUnit; +import me.lucko.spark.api.statistic.misc.DoubleAverageInfo;  /** - * Calculates the servers TPS (ticks per second) rate. - * - * <p>The code use to calculate the TPS is the same as the code used by the Minecraft server itself. - * This means that this class will output values the same as the /tps command.</p> - * - * <p>We calculate our own values instead of pulling them from the server for two reasons. Firstly, - * it's easier - pulling from the server requires reflection code on each of the platforms, we'd - * rather avoid that. Secondly, it allows us to generate rolling averages over a shorter period of - * time.</p> + * Provides the server TPS (ticks per second) and MSPT (milliseconds per tick) rates.   */ -public class TickStatistics implements TickHook.Callback, TickReporter.Callback { - -    private static final long SEC_IN_NANO = TimeUnit.SECONDS.toNanos(1); -    private static final int TPS = 20; -    private static final int TPS_SAMPLE_INTERVAL = 20; -    private static final BigDecimal TPS_BASE = new BigDecimal(SEC_IN_NANO).multiply(new BigDecimal(TPS_SAMPLE_INTERVAL)); - -    private final TpsRollingAverage tps5Sec = new TpsRollingAverage(5); -    private final TpsRollingAverage tps10Sec = new TpsRollingAverage(10); -    private final TpsRollingAverage tps1Min = new TpsRollingAverage(60); -    private final TpsRollingAverage tps5Min = new TpsRollingAverage(60 * 5); -    private final TpsRollingAverage tps15Min = new TpsRollingAverage(60 * 15); -    private final TpsRollingAverage[] tpsAverages = {this.tps5Sec, this.tps10Sec, this.tps1Min, this.tps5Min, this.tps15Min}; - -    private boolean durationSupported = false; -    private final RollingAverage tickDuration10Sec = new RollingAverage(TPS * 10); -    private final RollingAverage tickDuration1Min = new RollingAverage(TPS * 60); -    private final RollingAverage tickDuration5Min = new RollingAverage(TPS * 60 * 5); -    private final RollingAverage[] tickDurationAverages = {this.tickDuration10Sec, this.tickDuration1Min, this.tickDuration5Min}; - -    private long last = 0; - -    public boolean isDurationSupported() { -        return this.durationSupported; -    } - -    @Override -    public void onTick(int currentTick) { -        if (currentTick % TPS_SAMPLE_INTERVAL != 0) { -            return; -        } - -        long now = System.nanoTime(); - -        if (this.last == 0) { -            this.last = now; -            return; -        } - -        long diff = now - this.last; -        BigDecimal currentTps = TPS_BASE.divide(new BigDecimal(diff), 30, RoundingMode.HALF_UP); -        BigDecimal total = currentTps.multiply(new BigDecimal(diff)); - -        for (TpsRollingAverage rollingAverage : this.tpsAverages) { -            rollingAverage.add(currentTps, diff, total); -        } - -        this.last = now; -    } - -    @Override -    public void onTick(double duration) { -        this.durationSupported = true; -        BigDecimal decimal = new BigDecimal(duration); -        for (RollingAverage rollingAverage : this.tickDurationAverages) { -            rollingAverage.add(decimal); -        } -    } - -    public double tps5Sec() { -        return this.tps5Sec.getAverage(); -    } - -    public double tps10Sec() { -        return this.tps10Sec.getAverage(); -    } - -    public double tps1Min() { -        return this.tps1Min.getAverage(); -    } - -    public double tps5Min() { -        return this.tps5Min.getAverage(); -    } - -    public double tps15Min() { -        return this.tps15Min.getAverage(); -    } - -    public RollingAverage duration10Sec() { -        if (!this.durationSupported) { -            return null; -        } -        return this.tickDuration10Sec; -    } - -    public RollingAverage duration1Min() { -        if (!this.durationSupported) { -            return null; -        } -        return this.tickDuration1Min; -    } - -    public RollingAverage duration5Min() { -        if (!this.durationSupported) { -            return null; -        } -        return this.tickDuration5Min; -    } - - -    /** -     * Rolling average calculator. -     * -     * <p>This code is taken from PaperMC/Paper, licensed under MIT.</p> -     * -     * @author aikar (PaperMC) https://github.com/PaperMC/Paper/blob/master/Spigot-Server-Patches/0021-Further-improve-server-tick-loop.patch -     */ -    public static final class TpsRollingAverage { -        private final int size; -        private long time; -        private BigDecimal total; -        private int index = 0; -        private final BigDecimal[] samples; -        private final long[] times; +public interface TickStatistics { -        TpsRollingAverage(int size) { -            this.size = size; -            this.time = size * SEC_IN_NANO; -            this.total = new BigDecimal(TPS).multiply(new BigDecimal(SEC_IN_NANO)).multiply(new BigDecimal(size)); -            this.samples = new BigDecimal[size]; -            this.times = new long[size]; -            for (int i = 0; i < size; i++) { -                this.samples[i] = new BigDecimal(TPS); -                this.times[i] = SEC_IN_NANO; -            } -        } +    double tps5Sec(); +    double tps10Sec(); +    double tps1Min(); +    double tps5Min(); +    double tps15Min(); -        public void add(BigDecimal x, long t, BigDecimal total) { -            this.time -= this.times[this.index]; -            this.total = this.total.subtract(this.samples[this.index].multiply(new BigDecimal(this.times[this.index]))); -            this.samples[this.index] = x; -            this.times[this.index] = t; -            this.time += t; -            this.total = this.total.add(total); -            if (++this.index == this.size) { -                this.index = 0; -            } -        } +    boolean isDurationSupported(); -        public double getAverage() { -            return this.total.divide(new BigDecimal(this.time), 30, RoundingMode.HALF_UP).doubleValue(); -        } -    } +    DoubleAverageInfo duration10Sec(); +    DoubleAverageInfo duration1Min(); +    DoubleAverageInfo duration5Min();  } diff --git a/spark-common/src/main/java/me/lucko/spark/common/platform/PlatformStatisticsProvider.java b/spark-common/src/main/java/me/lucko/spark/common/platform/PlatformStatisticsProvider.java index fc7e78a..b0987c9 100644 --- a/spark-common/src/main/java/me/lucko/spark/common/platform/PlatformStatisticsProvider.java +++ b/spark-common/src/main/java/me/lucko/spark/common/platform/PlatformStatisticsProvider.java @@ -20,6 +20,7 @@  package me.lucko.spark.common.platform; +import me.lucko.spark.api.statistic.misc.DoubleAverageInfo;  import me.lucko.spark.common.SparkPlatform;  import me.lucko.spark.common.monitor.cpu.CpuInfo;  import me.lucko.spark.common.monitor.cpu.CpuMonitor; @@ -33,6 +34,7 @@ import me.lucko.spark.common.monitor.ping.PingStatistics;  import me.lucko.spark.common.monitor.tick.TickStatistics;  import me.lucko.spark.common.platform.world.AsyncWorldInfoProvider;  import me.lucko.spark.common.platform.world.WorldStatisticsProvider; +import me.lucko.spark.proto.SparkProtos;  import me.lucko.spark.proto.SparkProtos.PlatformStatistics;  import me.lucko.spark.proto.SparkProtos.SystemStatistics;  import me.lucko.spark.proto.SparkProtos.WorldStatistics; @@ -118,17 +120,17 @@ public class PlatformStatisticsProvider {          networkInterfaceStats.forEach((name, statistics) -> builder.putNet(                  name,                  SystemStatistics.NetInterface.newBuilder() -                        .setRxBytesPerSecond(statistics.rxBytesPerSecond().toProto()) -                        .setRxPacketsPerSecond(statistics.rxPacketsPerSecond().toProto()) -                        .setTxBytesPerSecond(statistics.txBytesPerSecond().toProto()) -                        .setTxPacketsPerSecond(statistics.txPacketsPerSecond().toProto()) +                        .setRxBytesPerSecond(rollingAvgProto(statistics.rxBytesPerSecond())) +                        .setRxPacketsPerSecond(rollingAvgProto(statistics.rxPacketsPerSecond())) +                        .setTxBytesPerSecond(rollingAvgProto(statistics.txBytesPerSecond())) +                        .setTxPacketsPerSecond(rollingAvgProto(statistics.txPacketsPerSecond()))                          .build()          ));          return builder.build();      } -    public PlatformStatistics getPlatformStatistics(Map<String, GarbageCollectorStatistics> startingGcStatistics) { +    public PlatformStatistics getPlatformStatistics(Map<String, GarbageCollectorStatistics> startingGcStatistics, boolean includeWorld) {          PlatformStatistics.Builder builder = PlatformStatistics.newBuilder();          MemoryUsage memoryUsage = ManagementFactory.getMemoryMXBean().getHeapMemoryUsage(); @@ -166,8 +168,8 @@ public class PlatformStatisticsProvider {              );              if (tickStatistics.isDurationSupported()) {                  builder.setMspt(PlatformStatistics.Mspt.newBuilder() -                        .setLast1M(tickStatistics.duration1Min().toProto()) -                        .setLast5M(tickStatistics.duration5Min().toProto()) +                        .setLast1M(rollingAvgProto(tickStatistics.duration1Min())) +                        .setLast5M(rollingAvgProto(tickStatistics.duration5Min()))                          .build()                  );              } @@ -176,7 +178,7 @@ public class PlatformStatisticsProvider {          PingStatistics pingStatistics = this.platform.getPingStatistics();          if (pingStatistics != null && pingStatistics.getPingAverage().getSamples() != 0) {              builder.setPing(PlatformStatistics.Ping.newBuilder() -                    .setLast15M(pingStatistics.getPingAverage().toProto()) +                    .setLast15M(rollingAvgProto(pingStatistics.getPingAverage()))                      .build()              );          } @@ -187,20 +189,31 @@ public class PlatformStatisticsProvider {              builder.setPlayerCount(playerCount);          } -        try { -            WorldStatisticsProvider worldStatisticsProvider = new WorldStatisticsProvider( -                    new AsyncWorldInfoProvider(this.platform, this.platform.getPlugin().createWorldInfoProvider()) -            ); -            WorldStatistics worldStatistics = worldStatisticsProvider.getWorldStatistics(); -            if (worldStatistics != null) { -                builder.setWorld(worldStatistics); +        if (includeWorld) { +            try { +                WorldStatisticsProvider worldStatisticsProvider = new WorldStatisticsProvider( +                        new AsyncWorldInfoProvider(this.platform, this.platform.getPlugin().createWorldInfoProvider()) +                ); +                WorldStatistics worldStatistics = worldStatisticsProvider.getWorldStatistics(); +                if (worldStatistics != null) { +                    builder.setWorld(worldStatistics); +                } +            } catch (Exception e) { +                e.printStackTrace();              } -        } catch (Exception e) { -            e.printStackTrace();          } -          return builder.build();      } +    public static SparkProtos.RollingAverageValues rollingAvgProto(DoubleAverageInfo info) { +        return SparkProtos.RollingAverageValues.newBuilder() +                .setMean(info.mean()) +                .setMax(info.max()) +                .setMin(info.min()) +                .setMedian(info.median()) +                .setPercentile95(info.percentile95th()) +                .build(); +    } +  } diff --git a/spark-common/src/main/java/me/lucko/spark/common/sampler/AbstractSampler.java b/spark-common/src/main/java/me/lucko/spark/common/sampler/AbstractSampler.java index e324fd3..d814002 100644 --- a/spark-common/src/main/java/me/lucko/spark/common/sampler/AbstractSampler.java +++ b/spark-common/src/main/java/me/lucko/spark/common/sampler/AbstractSampler.java @@ -32,9 +32,12 @@ import me.lucko.spark.common.sampler.source.ClassSourceLookup;  import me.lucko.spark.common.sampler.source.SourceMetadata;  import me.lucko.spark.common.sampler.window.ProtoTimeEncoder;  import me.lucko.spark.common.sampler.window.WindowStatisticsCollector; +import me.lucko.spark.common.ws.ViewerSocket; +import me.lucko.spark.proto.SparkProtos;  import me.lucko.spark.proto.SparkSamplerProtos.SamplerData;  import me.lucko.spark.proto.SparkSamplerProtos.SamplerMetadata; +import java.util.ArrayList;  import java.util.Collection;  import java.util.Comparator;  import java.util.List; @@ -74,6 +77,9 @@ public abstract class AbstractSampler implements Sampler {      /** The garbage collector statistics when profiling started */      protected Map<String, GarbageCollectorStatistics> initialGcStats; +    /** A set of viewer sockets linked to the sampler */ +    protected List<ViewerSocket> viewerSockets = new ArrayList<>(); +      protected AbstractSampler(SparkPlatform platform, SamplerSettings settings) {          this.platform = platform;          this.interval = settings.interval(); @@ -122,12 +128,54 @@ public abstract class AbstractSampler implements Sampler {      @Override      public void stop(boolean cancelled) {          this.windowStatisticsCollector.stop(); +        for (ViewerSocket viewerSocket : this.viewerSockets) { +            viewerSocket.processSamplerStopped(this); +        } +    } + +    @Override +    public void attachSocket(ViewerSocket socket) { +        this.viewerSockets.add(socket); +    } + +    @Override +    public Collection<ViewerSocket> getAttachedSockets() { +        return this.viewerSockets; +    } + +    protected void processWindowRotate() { +        this.viewerSockets.removeIf(socket -> { +            if (!socket.isOpen()) { +                return true; +            } + +            socket.processWindowRotate(this); +            return false; +        }); +    } + +    protected void sendStatisticsToSocket() { +        try { +            if (this.viewerSockets.isEmpty()) { +                return; +            } + +            SparkProtos.PlatformStatistics platform = this.platform.getStatisticsProvider().getPlatformStatistics(getInitialGcStats(), false); +            SparkProtos.SystemStatistics system = this.platform.getStatisticsProvider().getSystemStatistics(); + +            for (ViewerSocket viewerSocket : this.viewerSockets) { +                viewerSocket.sendUpdatedStatistics(platform, system); +            } +        } catch (Exception e) { +            e.printStackTrace(); +        }      } -    protected void writeMetadataToProto(SamplerData.Builder proto, SparkPlatform platform, CommandSender creator, String comment, DataAggregator dataAggregator) { +    protected void writeMetadataToProto(SamplerData.Builder proto, SparkPlatform platform, CommandSender.Data creator, String comment, DataAggregator dataAggregator) {          SamplerMetadata.Builder metadata = SamplerMetadata.newBuilder() +                .setSamplerMode(getMode().asProto())                  .setPlatformMetadata(platform.getPlugin().getPlatformInfo().toData().toProto()) -                .setCreator(creator.toData().toProto()) +                .setCreator(creator.toProto())                  .setStartTime(this.startTime)                  .setEndTime(System.currentTimeMillis())                  .setInterval(this.interval) @@ -144,7 +192,7 @@ public abstract class AbstractSampler implements Sampler {          }          try { -            metadata.setPlatformStatistics(platform.getStatisticsProvider().getPlatformStatistics(getInitialGcStats())); +            metadata.setPlatformStatistics(platform.getStatisticsProvider().getPlatformStatistics(getInitialGcStats(), true));          } catch (Exception e) {              e.printStackTrace();          } @@ -187,7 +235,7 @@ public abstract class AbstractSampler implements Sampler {          ClassSourceLookup.Visitor classSourceVisitor = ClassSourceLookup.createVisitor(classSourceLookup); -        ProtoTimeEncoder timeEncoder = new ProtoTimeEncoder(data); +        ProtoTimeEncoder timeEncoder = new ProtoTimeEncoder(getMode().valueTransformer(), data);          int[] timeWindows = timeEncoder.getKeys();          for (int timeWindow : timeWindows) {              proto.addTimeWindows(timeWindow); diff --git a/spark-common/src/main/java/me/lucko/spark/common/sampler/BackgroundSamplerManager.java b/spark-common/src/main/java/me/lucko/spark/common/sampler/BackgroundSamplerManager.java index 7e3b6b4..4e9ca9e 100644 --- a/spark-common/src/main/java/me/lucko/spark/common/sampler/BackgroundSamplerManager.java +++ b/spark-common/src/main/java/me/lucko/spark/common/sampler/BackgroundSamplerManager.java @@ -31,6 +31,8 @@ public class BackgroundSamplerManager {      private static final String OPTION_ENABLED = "backgroundProfiler";      private static final String OPTION_ENGINE = "backgroundProfilerEngine";      private static final String OPTION_INTERVAL = "backgroundProfilerInterval"; +    private static final String OPTION_THREAD_GROUPER = "backgroundProfilerThreadGrouper"; +    private static final String OPTION_THREAD_DUMPER = "backgroundProfilerThreadDumper";      private static final String MARKER_FAILED = "_marker_background_profiler_failed"; @@ -101,13 +103,21 @@ public class BackgroundSamplerManager {      private void startSampler() {          boolean forceJavaEngine = this.configuration.getString(OPTION_ENGINE, "async").equals("java"); +        ThreadGrouper threadGrouper = ThreadGrouper.parseConfigSetting(this.configuration.getString(OPTION_THREAD_GROUPER, "by-pool")); +        ThreadDumper threadDumper = ThreadDumper.parseConfigSetting(this.configuration.getString(OPTION_THREAD_DUMPER, "default")); +        if (threadDumper == null) { +            threadDumper = this.platform.getPlugin().getDefaultThreadDumper(); +        } + +        int interval = this.configuration.getInteger(OPTION_INTERVAL, 10); +          Sampler sampler = new SamplerBuilder() -                .background(true) -                .threadDumper(this.platform.getPlugin().getDefaultThreadDumper()) -                .threadGrouper(ThreadGrouper.BY_POOL) -                .samplingInterval(this.configuration.getInteger(OPTION_INTERVAL, 10)) -                .forceJavaSampler(forceJavaEngine) -                .start(this.platform); +              .background(true) +              .threadDumper(threadDumper) +              .threadGrouper(threadGrouper) +              .samplingInterval(interval) +              .forceJavaSampler(forceJavaEngine) +              .start(this.platform);          this.platform.getSamplerContainer().setActiveSampler(sampler);      } diff --git a/spark-common/src/main/java/me/lucko/spark/common/sampler/Sampler.java b/spark-common/src/main/java/me/lucko/spark/common/sampler/Sampler.java index 36a63f1..844ab0b 100644 --- a/spark-common/src/main/java/me/lucko/spark/common/sampler/Sampler.java +++ b/spark-common/src/main/java/me/lucko/spark/common/sampler/Sampler.java @@ -24,9 +24,13 @@ import me.lucko.spark.common.SparkPlatform;  import me.lucko.spark.common.command.sender.CommandSender;  import me.lucko.spark.common.sampler.node.MergeMode;  import me.lucko.spark.common.sampler.source.ClassSourceLookup; +import me.lucko.spark.common.ws.ViewerSocket;  import me.lucko.spark.proto.SparkSamplerProtos.SamplerData; +import me.lucko.spark.proto.SparkSamplerProtos.SocketChannelInfo; +import java.util.Collection;  import java.util.concurrent.CompletableFuture; +import java.util.function.Supplier;  /**   * Abstract superinterface for all sampler implementations. @@ -44,6 +48,20 @@ public interface Sampler {      void stop(boolean cancelled);      /** +     * Attaches a viewer socket to this sampler. +     * +     * @param socket the socket +     */ +    void attachSocket(ViewerSocket socket); + +    /** +     * Gets the sockets attached to this sampler. +     * +     * @return the attached sockets +     */ +    Collection<ViewerSocket> getAttachedSockets(); + +    /**       * Gets the time when the sampler started (unix timestamp in millis)       *       * @return the start time @@ -65,6 +83,13 @@ public interface Sampler {      boolean isRunningInBackground();      /** +     * Gets the sampler mode. +     * +     * @return the sampler mode +     */ +    SamplerMode getMode(); + +    /**       * Gets a future to encapsulate the completion of the sampler       *       * @return a future @@ -72,6 +97,62 @@ public interface Sampler {      CompletableFuture<Sampler> getFuture();      // Methods used to export the sampler data to the web viewer. -    SamplerData toProto(SparkPlatform platform, CommandSender creator, String comment, MergeMode mergeMode, ClassSourceLookup classSourceLookup); +    SamplerData toProto(SparkPlatform platform, ExportProps exportProps); + +    final class ExportProps { +        private CommandSender.Data creator; +        private String comment; +        private Supplier<MergeMode> mergeMode; +        private Supplier<ClassSourceLookup> classSourceLookup; +        private SocketChannelInfo channelInfo; + +        public ExportProps() { +        } + +        public CommandSender.Data creator() { +            return this.creator; +        } + +        public String comment() { +            return this.comment; +        } + +        public Supplier<MergeMode> mergeMode() { +            return this.mergeMode; +        } + +        public Supplier<ClassSourceLookup> classSourceLookup() { +            return this.classSourceLookup; +        } + +        public SocketChannelInfo channelInfo() { +            return this.channelInfo; +        } + +        public ExportProps creator(CommandSender.Data creator) { +            this.creator = creator; +            return this; +        } + +        public ExportProps comment(String comment) { +            this.comment = comment; +            return this; +        } + +        public ExportProps mergeMode(Supplier<MergeMode> mergeMode) { +            this.mergeMode = mergeMode; +            return this; +        } + +        public ExportProps classSourceLookup(Supplier<ClassSourceLookup> classSourceLookup) { +            this.classSourceLookup = classSourceLookup; +            return this; +        } + +        public ExportProps channelInfo(SocketChannelInfo channelInfo) { +            this.channelInfo = channelInfo; +            return this; +        } +    }  } diff --git a/spark-common/src/main/java/me/lucko/spark/common/sampler/SamplerBuilder.java b/spark-common/src/main/java/me/lucko/spark/common/sampler/SamplerBuilder.java index ec635ef..b6895ce 100644 --- a/spark-common/src/main/java/me/lucko/spark/common/sampler/SamplerBuilder.java +++ b/spark-common/src/main/java/me/lucko/spark/common/sampler/SamplerBuilder.java @@ -23,6 +23,7 @@ package me.lucko.spark.common.sampler;  import me.lucko.spark.common.SparkPlatform;  import me.lucko.spark.common.sampler.async.AsyncProfilerAccess;  import me.lucko.spark.common.sampler.async.AsyncSampler; +import me.lucko.spark.common.sampler.async.SampleCollector;  import me.lucko.spark.common.sampler.java.JavaSampler;  import me.lucko.spark.common.tick.TickHook; @@ -34,10 +35,12 @@ import java.util.concurrent.TimeUnit;  @SuppressWarnings("UnusedReturnValue")  public class SamplerBuilder { -    private double samplingInterval = 4; // milliseconds +    private SamplerMode mode = SamplerMode.EXECUTION; +    private double samplingInterval = -1;      private boolean ignoreSleeping = false;      private boolean ignoreNative = false;      private boolean useAsyncProfiler = true; +    private boolean allocLiveOnly = false;      private long autoEndTime = -1;      private boolean background = false;      private ThreadDumper threadDumper = ThreadDumper.ALL; @@ -49,6 +52,11 @@ public class SamplerBuilder {      public SamplerBuilder() {      } +    public SamplerBuilder mode(SamplerMode mode) { +        this.mode = mode; +        return this; +    } +      public SamplerBuilder samplingInterval(double samplingInterval) {          this.samplingInterval = samplingInterval;          return this; @@ -98,21 +106,38 @@ public class SamplerBuilder {          return this;      } -    public Sampler start(SparkPlatform platform) { +    public SamplerBuilder allocLiveOnly(boolean allocLiveOnly) { +        this.allocLiveOnly = allocLiveOnly; +        return this; +    } + +    public Sampler start(SparkPlatform platform) throws UnsupportedOperationException { +        if (this.samplingInterval <= 0) { +            throw new IllegalArgumentException("samplingInterval = " + this.samplingInterval); +        } +          boolean onlyTicksOverMode = this.ticksOver != -1 && this.tickHook != null;          boolean canUseAsyncProfiler = this.useAsyncProfiler &&                  !onlyTicksOverMode &&                  !(this.ignoreSleeping || this.ignoreNative) && -                !(this.threadDumper instanceof ThreadDumper.Regex) &&                  AsyncProfilerAccess.getInstance(platform).checkSupported(platform); +        if (this.mode == SamplerMode.ALLOCATION && (!canUseAsyncProfiler || !AsyncProfilerAccess.getInstance(platform).checkAllocationProfilingSupported(platform))) { +            throw new UnsupportedOperationException("Allocation profiling is not supported on your system. Check the console for more info."); +        } + +        int interval = (int) (this.mode == SamplerMode.EXECUTION ? +                this.samplingInterval * 1000d : // convert to microseconds +                this.samplingInterval +        ); -        int intervalMicros = (int) (this.samplingInterval * 1000d); -        SamplerSettings settings = new SamplerSettings(intervalMicros, this.threadDumper, this.threadGrouper, this.autoEndTime, this.background); +        SamplerSettings settings = new SamplerSettings(interval, this.threadDumper, this.threadGrouper, this.autoEndTime, this.background);          Sampler sampler; -        if (canUseAsyncProfiler) { -            sampler = new AsyncSampler(platform, settings); +        if (this.mode == SamplerMode.ALLOCATION) { +            sampler = new AsyncSampler(platform, settings, new SampleCollector.Allocation(interval, this.allocLiveOnly)); +        } else if (canUseAsyncProfiler) { +            sampler = new AsyncSampler(platform, settings, new SampleCollector.Execution(interval));          } else if (onlyTicksOverMode) {              sampler = new JavaSampler(platform, settings, this.ignoreSleeping, this.ignoreNative, this.tickHook, this.ticksOver);          } else { diff --git a/spark-common/src/main/java/me/lucko/spark/common/sampler/SamplerMode.java b/spark-common/src/main/java/me/lucko/spark/common/sampler/SamplerMode.java new file mode 100644 index 0000000..f9a6e41 --- /dev/null +++ b/spark-common/src/main/java/me/lucko/spark/common/sampler/SamplerMode.java @@ -0,0 +1,74 @@ +/* + * This file is part of spark. + * + *  Copyright (c) lucko (Luck) <luck@lucko.me> + *  Copyright (c) contributors + * + *  This program is free software: you can redistribute it and/or modify + *  it under the terms of the GNU General Public License as published by + *  the Free Software Foundation, either version 3 of the License, or + *  (at your option) any later version. + * + *  This program is distributed in the hope that it will be useful, + *  but WITHOUT ANY WARRANTY; without even the implied warranty of + *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the + *  GNU General Public License for more details. + * + *  You should have received a copy of the GNU General Public License + *  along with this program.  If not, see <http://www.gnu.org/licenses/>. + */ + +package me.lucko.spark.common.sampler; + +import me.lucko.spark.proto.SparkSamplerProtos.SamplerMetadata; + +import java.util.function.LongToDoubleFunction; + +public enum SamplerMode { + +    EXECUTION( +            value -> { +                // convert the duration from microseconds -> milliseconds +                return value / 1000d; +            }, +            4, // ms +            SamplerMetadata.SamplerMode.EXECUTION +    ), + +    ALLOCATION( +            value -> { +                // do nothing +                return value; +            }, +            524287, // 512 KiB +            SamplerMetadata.SamplerMode.ALLOCATION +    ); + +    private final LongToDoubleFunction valueTransformer; +    private final int defaultInterval; +    private final SamplerMetadata.SamplerMode proto; + +    SamplerMode(LongToDoubleFunction valueTransformer, int defaultInterval, SamplerMetadata.SamplerMode proto) { +        this.valueTransformer = valueTransformer; +        this.defaultInterval = defaultInterval; +        this.proto = proto; +    } + +    public LongToDoubleFunction valueTransformer() { +        return this.valueTransformer; +    } + +    public int defaultInterval() { +        return this.defaultInterval; +    } + +    /** +     * Gets the metadata enum instance for this sampler mode. +     * +     * @return proto metadata +     */ +    public SamplerMetadata.SamplerMode asProto() { +        return this.proto; +    } + +} diff --git a/spark-common/src/main/java/me/lucko/spark/common/sampler/ThreadDumper.java b/spark-common/src/main/java/me/lucko/spark/common/sampler/ThreadDumper.java index fd0c413..c68384b 100644 --- a/spark-common/src/main/java/me/lucko/spark/common/sampler/ThreadDumper.java +++ b/spark-common/src/main/java/me/lucko/spark/common/sampler/ThreadDumper.java @@ -32,7 +32,6 @@ import java.util.Objects;  import java.util.Set;  import java.util.function.Supplier;  import java.util.regex.Pattern; -import java.util.regex.PatternSyntaxException;  import java.util.stream.Collectors;  /** @@ -50,11 +49,38 @@ public interface ThreadDumper {      ThreadInfo[] dumpThreads(ThreadMXBean threadBean);      /** +     * Gets if the given thread should be included in the output. +     * +     * @param threadId the thread id +     * @param threadName the thread name +     * @return if the thread should be included +     */ +    boolean isThreadIncluded(long threadId, String threadName); + +    /**       * Gets metadata about the thread dumper instance.       */      SamplerMetadata.ThreadDumper getMetadata();      /** +     * Creates a new {@link ThreadDumper} by parsing the given config setting. +     * +     * @param setting the config setting +     * @return the thread dumper +     */ +    static ThreadDumper parseConfigSetting(String setting) { +        switch (setting) { +            case "default": +                return null; +            case "all": +                return ALL; +            default: +                Set<String> threadNames = Arrays.stream(setting.split(",")).collect(Collectors.toSet()); +                return new ThreadDumper.Specific(threadNames); +        } +    } + +    /**       * Implementation of {@link ThreadDumper} that generates data for all threads.       */      ThreadDumper ALL = new ThreadDumper() { @@ -64,6 +90,11 @@ public interface ThreadDumper {          }          @Override +        public boolean isThreadIncluded(long threadId, String threadName) { +            return true; +        } + +        @Override          public SamplerMetadata.ThreadDumper getMetadata() {              return SamplerMetadata.ThreadDumper.newBuilder()                      .setType(SamplerMetadata.ThreadDumper.Type.ALL) @@ -98,7 +129,7 @@ public interface ThreadDumper {          }          public void setThread(Thread thread) { -            this.dumper = new Specific(new long[]{thread.getId()}); +            this.dumper = new Specific(thread);          }      } @@ -114,10 +145,6 @@ public interface ThreadDumper {              this.ids = new long[]{thread.getId()};          } -        public Specific(long[] ids) { -            this.ids = ids; -        } -          public Specific(Set<String> names) {              this.threadNamesLowerCase = names.stream().map(String::toLowerCase).collect(Collectors.toSet());              this.ids = new ThreadFinder().getThreads() @@ -146,6 +173,14 @@ public interface ThreadDumper {          }          @Override +        public boolean isThreadIncluded(long threadId, String threadName) { +            if (Arrays.binarySearch(this.ids, threadId) >= 0) { +                return true; +            } +            return getThreadNames().contains(threadName.toLowerCase()); +        } + +        @Override          public ThreadInfo[] dumpThreads(ThreadMXBean threadBean) {              return threadBean.getThreadInfo(this.ids, Integer.MAX_VALUE);          } @@ -169,35 +204,31 @@ public interface ThreadDumper {          public Regex(Set<String> namePatterns) {              this.namePatterns = namePatterns.stream() -                    .map(regex -> { -                        try { -                            return Pattern.compile(regex, Pattern.CASE_INSENSITIVE); -                        } catch (PatternSyntaxException e) { -                            return null; -                        } -                    }) -                    .filter(Objects::nonNull) +                    .map(regex -> Pattern.compile(regex, Pattern.CASE_INSENSITIVE))                      .collect(Collectors.toSet());          }          @Override +        public boolean isThreadIncluded(long threadId, String threadName) { +            Boolean result = this.cache.get(threadId); +            if (result != null) { +                return result; +            } + +            for (Pattern pattern : this.namePatterns) { +                if (pattern.matcher(threadName).matches()) { +                    this.cache.put(threadId, true); +                    return true; +                } +            } +            this.cache.put(threadId, false); +            return false; +        } + +        @Override          public ThreadInfo[] dumpThreads(ThreadMXBean threadBean) {              return this.threadFinder.getThreads() -                    .filter(thread -> { -                        Boolean result = this.cache.get(thread.getId()); -                        if (result != null) { -                            return result; -                        } - -                        for (Pattern pattern : this.namePatterns) { -                            if (pattern.matcher(thread.getName()).matches()) { -                                this.cache.put(thread.getId(), true); -                                return true; -                            } -                        } -                        this.cache.put(thread.getId(), false); -                        return false; -                    }) +                    .filter(thread -> isThreadIncluded(thread.getId(), thread.getName()))                      .map(thread -> threadBean.getThreadInfo(thread.getId(), Integer.MAX_VALUE))                      .filter(Objects::nonNull)                      .toArray(ThreadInfo[]::new); diff --git a/spark-common/src/main/java/me/lucko/spark/common/sampler/ThreadGrouper.java b/spark-common/src/main/java/me/lucko/spark/common/sampler/ThreadGrouper.java index 9ad84df..b6cfbea 100644 --- a/spark-common/src/main/java/me/lucko/spark/common/sampler/ThreadGrouper.java +++ b/spark-common/src/main/java/me/lucko/spark/common/sampler/ThreadGrouper.java @@ -35,6 +35,47 @@ import java.util.regex.Pattern;  public interface ThreadGrouper {      /** +     * Gets the group for the given thread. +     * +     * @param threadId the id of the thread +     * @param threadName the name of the thread +     * @return the group +     */ +    String getGroup(long threadId, String threadName); + +    /** +     * Gets the label to use for a given group. +     * +     * @param group the group +     * @return the label +     */ +    String getLabel(String group); + +    /** +     * Gets the metadata enum instance for this thread grouper. +     * +     * @return proto metadata +     */ +    SamplerMetadata.DataAggregator.ThreadGrouper asProto(); + +    /** +     * Creates a new {@link ThreadGrouper} by parsing the given config setting. +     * +     * @param setting the config setting +     * @return the thread grouper +     */ +    static ThreadGrouper parseConfigSetting(String setting) { +        switch (setting) { +            case "as-one": +                return AS_ONE; +            case "by-name": +                return BY_NAME; +            default: +                return BY_POOL; +        } +    } + +    /**       * Implementation of {@link ThreadGrouper} that just groups by thread name.       */      ThreadGrouper BY_NAME = new ThreadGrouper() { @@ -126,23 +167,4 @@ public interface ThreadGrouper {          }      }; -    /** -     * Gets the group for the given thread. -     * -     * @param threadId the id of the thread -     * @param threadName the name of the thread -     * @return the group -     */ -    String getGroup(long threadId, String threadName); - -    /** -     * Gets the label to use for a given group. -     * -     * @param group the group -     * @return the label -     */ -    String getLabel(String group); - -    SamplerMetadata.DataAggregator.ThreadGrouper asProto(); -  } diff --git a/spark-common/src/main/java/me/lucko/spark/common/sampler/async/AsyncDataAggregator.java b/spark-common/src/main/java/me/lucko/spark/common/sampler/async/AsyncDataAggregator.java index 402330a..b9a80e0 100644 --- a/spark-common/src/main/java/me/lucko/spark/common/sampler/async/AsyncDataAggregator.java +++ b/spark-common/src/main/java/me/lucko/spark/common/sampler/async/AsyncDataAggregator.java @@ -50,7 +50,7 @@ public class AsyncDataAggregator extends AbstractDataAggregator {      public void insertData(ProfileSegment element, int window) {          try {              ThreadNode node = getNode(this.threadGrouper.getGroup(element.getNativeThreadId(), element.getThreadName())); -            node.log(STACK_TRACE_DESCRIBER, element.getStackTrace(), element.getTime(), window); +            node.log(STACK_TRACE_DESCRIBER, element.getStackTrace(), element.getValue(), window);          } catch (Exception e) {              e.printStackTrace();          } diff --git a/spark-common/src/main/java/me/lucko/spark/common/sampler/async/AsyncProfilerAccess.java b/spark-common/src/main/java/me/lucko/spark/common/sampler/async/AsyncProfilerAccess.java index 1480650..5bee56f 100644 --- a/spark-common/src/main/java/me/lucko/spark/common/sampler/async/AsyncProfilerAccess.java +++ b/spark-common/src/main/java/me/lucko/spark/common/sampler/async/AsyncProfilerAccess.java @@ -61,6 +61,8 @@ public class AsyncProfilerAccess {      /** The event to use for profiling */      private final ProfilingEvent profilingEvent; +    /** The event to use for allocation profiling */ +    private final ProfilingEvent allocationProfilingEvent;      /** If profiler is null, contains the reason why setup failed */      private final Exception setupException; @@ -68,10 +70,16 @@ public class AsyncProfilerAccess {      AsyncProfilerAccess(SparkPlatform platform) {          AsyncProfiler profiler;          ProfilingEvent profilingEvent = null; +        ProfilingEvent allocationProfilingEvent = null;          Exception setupException = null;          try {              profiler = load(platform); + +            if (isEventSupported(profiler, ProfilingEvent.ALLOC, false)) { +                allocationProfilingEvent = ProfilingEvent.ALLOC; +            } +              if (isEventSupported(profiler, ProfilingEvent.CPU, false)) {                  profilingEvent = ProfilingEvent.CPU;              } else if (isEventSupported(profiler, ProfilingEvent.WALL, true)) { @@ -84,6 +92,7 @@ public class AsyncProfilerAccess {          this.profiler = profiler;          this.profilingEvent = profilingEvent; +        this.allocationProfilingEvent = allocationProfilingEvent;          this.setupException = setupException;      } @@ -98,6 +107,10 @@ public class AsyncProfilerAccess {          return this.profilingEvent;      } +    public ProfilingEvent getAllocationProfilingEvent() { +        return this.allocationProfilingEvent; +    } +      public boolean checkSupported(SparkPlatform platform) {          if (this.setupException != null) {              if (this.setupException instanceof UnsupportedSystemException) { @@ -116,6 +129,15 @@ public class AsyncProfilerAccess {          return this.profiler != null;      } +    public boolean checkAllocationProfilingSupported(SparkPlatform platform) { +        boolean supported = this.allocationProfilingEvent != null; +        if (!supported && this.profiler != null) { +            platform.getPlugin().log(Level.WARNING, "The allocation profiling mode is not supported on your system. This is most likely because Hotspot debug symbols are not available."); +            platform.getPlugin().log(Level.WARNING, "To resolve, try installing the 'openjdk-11-dbg' or 'openjdk-8-dbg' package using your OS package manager."); +        } +        return supported; +    } +      private static AsyncProfiler load(SparkPlatform platform) throws Exception {          // check compatibility          String os = System.getProperty("os.name").toLowerCase(Locale.ROOT).replace(" ", ""); @@ -183,7 +205,8 @@ public class AsyncProfilerAccess {      enum ProfilingEvent {          CPU(Events.CPU), -        WALL(Events.WALL); +        WALL(Events.WALL), +        ALLOC(Events.ALLOC);          private final String id; diff --git a/spark-common/src/main/java/me/lucko/spark/common/sampler/async/AsyncProfilerJob.java b/spark-common/src/main/java/me/lucko/spark/common/sampler/async/AsyncProfilerJob.java index d74b75f..2fd304c 100644 --- a/spark-common/src/main/java/me/lucko/spark/common/sampler/async/AsyncProfilerJob.java +++ b/spark-common/src/main/java/me/lucko/spark/common/sampler/async/AsyncProfilerJob.java @@ -20,6 +20,8 @@  package me.lucko.spark.common.sampler.async; +import com.google.common.collect.ImmutableList; +  import me.lucko.spark.common.SparkPlatform;  import me.lucko.spark.common.sampler.ThreadDumper;  import me.lucko.spark.common.sampler.async.jfr.JfrReader; @@ -29,10 +31,9 @@ import one.profiler.AsyncProfiler;  import java.io.IOException;  import java.nio.file.Files;  import java.nio.file.Path; +import java.util.Collection;  import java.util.List; -import java.util.concurrent.TimeUnit;  import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Predicate;  /**   * Represents a profiling job within async-profiler. @@ -77,8 +78,8 @@ public class AsyncProfilerJob {      // Set on init      /** The platform */      private SparkPlatform platform; -    /** The sampling interval in microseconds */ -    private int interval; +    /** The sample collector */ +    private SampleCollector<?> sampleCollector;      /** The thread dumper */      private ThreadDumper threadDumper;      /** The profiling window */ @@ -100,9 +101,9 @@ public class AsyncProfilerJob {       * @param command the command       * @return the output       */ -    private String execute(String command) { +    private String execute(Collection<String> command) {          try { -            return this.profiler.execute(command); +            return this.profiler.execute(String.join(",", command));          } catch (IOException e) {              throw new RuntimeException("Exception whilst executing profiler command", e);          } @@ -118,9 +119,9 @@ public class AsyncProfilerJob {      }      // Initialise the job -    public void init(SparkPlatform platform, int interval, ThreadDumper threadDumper, int window, boolean quiet) { +    public void init(SparkPlatform platform, SampleCollector<?> collector, ThreadDumper threadDumper, int window, boolean quiet) {          this.platform = platform; -        this.interval = interval; +        this.sampleCollector = collector;          this.threadDumper = threadDumper;          this.window = window;          this.quiet = quiet; @@ -141,16 +142,20 @@ public class AsyncProfilerJob {              }              // construct a command to send to async-profiler -            String command = "start,event=" + this.access.getProfilingEvent() + ",interval=" + this.interval + "us,threads,jfr,file=" + this.outputFile.toString(); +            ImmutableList.Builder<String> command = ImmutableList.<String>builder() +                    .add("start") +                    .addAll(this.sampleCollector.initArguments(this.access)) +                    .add("threads").add("jfr").add("file=" + this.outputFile.toString()); +              if (this.quiet) { -                command += ",loglevel=NONE"; +                command.add("loglevel=NONE");              }              if (this.threadDumper instanceof ThreadDumper.Specific) { -                command += ",filter"; +                command.add("filter");              }              // start the profiler -            String resp = execute(command).trim(); +            String resp = execute(command.build()).trim();              if (!resp.equalsIgnoreCase("profiling started")) {                  throw new RuntimeException("Unexpected response: " + resp); @@ -197,18 +202,9 @@ public class AsyncProfilerJob {       * Aggregates the collected data.       */      public void aggregate(AsyncDataAggregator dataAggregator) { - -        Predicate<String> threadFilter; -        if (this.threadDumper instanceof ThreadDumper.Specific) { -            ThreadDumper.Specific specificDumper = (ThreadDumper.Specific) this.threadDumper; -            threadFilter = n -> specificDumper.getThreadNames().contains(n.toLowerCase()); -        } else { -            threadFilter = n -> true; -        } -          // read the jfr file produced by async-profiler          try (JfrReader reader = new JfrReader(this.outputFile)) { -            readSegments(reader, threadFilter, dataAggregator, this.window); +            readSegments(reader, this.sampleCollector, dataAggregator);          } catch (Exception e) {              boolean fileExists;              try { @@ -235,34 +231,23 @@ public class AsyncProfilerJob {          }      } -    private void readSegments(JfrReader reader, Predicate<String> threadFilter, AsyncDataAggregator dataAggregator, int window) throws IOException { -        List<JfrReader.ExecutionSample> samples = reader.readAllEvents(JfrReader.ExecutionSample.class); -        for (int i = 0; i < samples.size(); i++) { -            JfrReader.ExecutionSample sample = samples.get(i); - -            long duration; -            if (i == 0) { -                // we don't really know the duration of the first sample, so just use the sampling -                // interval -                duration = this.interval; -            } else { -                // calculate the duration of the sample by calculating the time elapsed since the -                // previous sample -                duration = TimeUnit.NANOSECONDS.toMicros(sample.time - samples.get(i - 1).time); -            } - +    private <E extends JfrReader.Event> void readSegments(JfrReader reader, SampleCollector<E> collector, AsyncDataAggregator dataAggregator) throws IOException { +        List<E> samples = reader.readAllEvents(collector.eventClass()); +        for (E sample : samples) {              String threadName = reader.threads.get((long) sample.tid);              if (threadName == null) {                  continue;              } -            if (!threadFilter.test(threadName)) { +            if (!this.threadDumper.isThreadIncluded(sample.tid, threadName)) {                  continue;              } +            long value = collector.measure(sample); +              // parse the segment and give it to the data aggregator -            ProfileSegment segment = ProfileSegment.parseSegment(reader, sample, threadName, duration); -            dataAggregator.insertData(segment, window); +            ProfileSegment segment = ProfileSegment.parseSegment(reader, sample, threadName, value); +            dataAggregator.insertData(segment, this.window);          }      } diff --git a/spark-common/src/main/java/me/lucko/spark/common/sampler/async/AsyncSampler.java b/spark-common/src/main/java/me/lucko/spark/common/sampler/async/AsyncSampler.java index 178f055..961c3e9 100644 --- a/spark-common/src/main/java/me/lucko/spark/common/sampler/async/AsyncSampler.java +++ b/spark-common/src/main/java/me/lucko/spark/common/sampler/async/AsyncSampler.java @@ -23,17 +23,18 @@ package me.lucko.spark.common.sampler.async;  import com.google.common.util.concurrent.ThreadFactoryBuilder;  import me.lucko.spark.common.SparkPlatform; -import me.lucko.spark.common.command.sender.CommandSender;  import me.lucko.spark.common.sampler.AbstractSampler; +import me.lucko.spark.common.sampler.SamplerMode;  import me.lucko.spark.common.sampler.SamplerSettings; -import me.lucko.spark.common.sampler.node.MergeMode; -import me.lucko.spark.common.sampler.source.ClassSourceLookup;  import me.lucko.spark.common.sampler.window.ProfilingWindowUtils;  import me.lucko.spark.common.tick.TickHook; +import me.lucko.spark.common.util.SparkThreadFactory; +import me.lucko.spark.common.ws.ViewerSocket;  import me.lucko.spark.proto.SparkSamplerProtos.SamplerData;  import java.util.concurrent.Executors;  import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture;  import java.util.concurrent.TimeUnit;  import java.util.function.IntPredicate; @@ -41,6 +42,11 @@ import java.util.function.IntPredicate;   * A sampler implementation using async-profiler.   */  public class AsyncSampler extends AbstractSampler { + +    /** Function to collect and measure samples - either execution or allocation */ +    private final SampleCollector<?> sampleCollector; + +    /** Object that provides access to the async-profiler API */      private final AsyncProfilerAccess profilerAccess;      /** Responsible for aggregating and then outputting collected sampling data */ @@ -55,12 +61,19 @@ public class AsyncSampler extends AbstractSampler {      /** The executor used for scheduling and management */      private ScheduledExecutorService scheduler; -    public AsyncSampler(SparkPlatform platform, SamplerSettings settings) { +    /** The task to send statistics to the viewer socket */ +    private ScheduledFuture<?> socketStatisticsTask; + +    public AsyncSampler(SparkPlatform platform, SamplerSettings settings, SampleCollector<?> collector) {          super(platform, settings); +        this.sampleCollector = collector;          this.profilerAccess = AsyncProfilerAccess.getInstance(platform);          this.dataAggregator = new AsyncDataAggregator(settings.threadGrouper());          this.scheduler = Executors.newSingleThreadScheduledExecutor( -                new ThreadFactoryBuilder().setNameFormat("spark-asyncsampler-worker-thread").build() +                new ThreadFactoryBuilder() +                        .setNameFormat("spark-async-sampler-worker-thread") +                        .setUncaughtExceptionHandler(SparkThreadFactory.EXCEPTION_HANDLER) +                        .build()          );      } @@ -79,17 +92,21 @@ public class AsyncSampler extends AbstractSampler {          int window = ProfilingWindowUtils.windowNow();          AsyncProfilerJob job = this.profilerAccess.startNewProfilerJob(); -        job.init(this.platform, this.interval, this.threadDumper, window, this.background); +        job.init(this.platform, this.sampleCollector, this.threadDumper, window, this.background);          job.start(); +        this.windowStatisticsCollector.recordWindowStartTime(window);          this.currentJob = job;          // rotate the sampler job to put data into a new window -        this.scheduler.scheduleAtFixedRate( -                this::rotateProfilerJob, -                ProfilingWindowUtils.WINDOW_SIZE_SECONDS, -                ProfilingWindowUtils.WINDOW_SIZE_SECONDS, -                TimeUnit.SECONDS -        ); +        boolean shouldNotRotate = this.sampleCollector instanceof SampleCollector.Allocation && ((SampleCollector.Allocation) this.sampleCollector).isLiveOnly(); +        if (!shouldNotRotate) { +            this.scheduler.scheduleAtFixedRate( +                    this::rotateProfilerJob, +                    ProfilingWindowUtils.WINDOW_SIZE_SECONDS, +                    ProfilingWindowUtils.WINDOW_SIZE_SECONDS, +                    TimeUnit.SECONDS +            ); +        }          recordInitialGcStats();          scheduleTimeout(); @@ -106,9 +123,6 @@ public class AsyncSampler extends AbstractSampler {                  try {                      // stop the previous job                      previousJob.stop(); - -                    // collect statistics for the window -                    this.windowStatisticsCollector.measureNow(previousJob.getWindow());                  } catch (Exception e) {                      e.printStackTrace();                  } @@ -116,10 +130,18 @@ public class AsyncSampler extends AbstractSampler {                  // start a new job                  int window = previousJob.getWindow() + 1;                  AsyncProfilerJob newJob = this.profilerAccess.startNewProfilerJob(); -                newJob.init(this.platform, this.interval, this.threadDumper, window, this.background); +                newJob.init(this.platform, this.sampleCollector, this.threadDumper, window, this.background);                  newJob.start(); +                this.windowStatisticsCollector.recordWindowStartTime(window);                  this.currentJob = newJob; +                // collect statistics for the previous window +                try { +                    this.windowStatisticsCollector.measureNow(previousJob.getWindow()); +                } catch (Exception e) { +                    e.printStackTrace(); +                } +                  // aggregate the output of the previous job                  previousJob.aggregate(this.dataAggregator); @@ -127,6 +149,8 @@ public class AsyncSampler extends AbstractSampler {                  IntPredicate predicate = ProfilingWindowUtils.keepHistoryBefore(window);                  this.dataAggregator.pruneData(predicate);                  this.windowStatisticsCollector.pruneStatistics(predicate); + +                this.scheduler.execute(this::processWindowRotate);              }          } catch (Throwable e) {              e.printStackTrace(); @@ -167,6 +191,10 @@ public class AsyncSampler extends AbstractSampler {              this.currentJob = null;          } +        if (this.socketStatisticsTask != null) { +            this.socketStatisticsTask.cancel(false); +        } +          if (this.scheduler != null) {              this.scheduler.shutdown();              this.scheduler = null; @@ -174,10 +202,27 @@ public class AsyncSampler extends AbstractSampler {      }      @Override -    public SamplerData toProto(SparkPlatform platform, CommandSender creator, String comment, MergeMode mergeMode, ClassSourceLookup classSourceLookup) { +    public void attachSocket(ViewerSocket socket) { +        super.attachSocket(socket); + +        if (this.socketStatisticsTask == null) { +            this.socketStatisticsTask = this.scheduler.scheduleAtFixedRate(this::sendStatisticsToSocket, 10, 10, TimeUnit.SECONDS); +        } +    } + +    @Override +    public SamplerMode getMode() { +        return this.sampleCollector.getMode(); +    } + +    @Override +    public SamplerData toProto(SparkPlatform platform, ExportProps exportProps) {          SamplerData.Builder proto = SamplerData.newBuilder(); -        writeMetadataToProto(proto, platform, creator, comment, this.dataAggregator); -        writeDataToProto(proto, this.dataAggregator, mergeMode, classSourceLookup); +        if (exportProps.channelInfo() != null) { +            proto.setChannelInfo(exportProps.channelInfo()); +        } +        writeMetadataToProto(proto, platform, exportProps.creator(), exportProps.comment(), this.dataAggregator); +        writeDataToProto(proto, this.dataAggregator, exportProps.mergeMode().get(), exportProps.classSourceLookup().get());          return proto.build();      } diff --git a/spark-common/src/main/java/me/lucko/spark/common/sampler/async/ProfileSegment.java b/spark-common/src/main/java/me/lucko/spark/common/sampler/async/ProfileSegment.java index 26debaf..0804ccf 100644 --- a/spark-common/src/main/java/me/lucko/spark/common/sampler/async/ProfileSegment.java +++ b/spark-common/src/main/java/me/lucko/spark/common/sampler/async/ProfileSegment.java @@ -38,13 +38,13 @@ public class ProfileSegment {      /** The stack trace for this segment */      private final AsyncStackTraceElement[] stackTrace;      /** The time spent executing this segment in microseconds */ -    private final long time; +    private final long value; -    public ProfileSegment(int nativeThreadId, String threadName, AsyncStackTraceElement[] stackTrace, long time) { +    public ProfileSegment(int nativeThreadId, String threadName, AsyncStackTraceElement[] stackTrace, long value) {          this.nativeThreadId = nativeThreadId;          this.threadName = threadName;          this.stackTrace = stackTrace; -        this.time = time; +        this.value = value;      }      public int getNativeThreadId() { @@ -59,11 +59,11 @@ public class ProfileSegment {          return this.stackTrace;      } -    public long getTime() { -        return this.time; +    public long getValue() { +        return this.value;      } -    public static ProfileSegment parseSegment(JfrReader reader, JfrReader.ExecutionSample sample, String threadName, long duration) { +    public static ProfileSegment parseSegment(JfrReader reader, JfrReader.Event sample, String threadName, long value) {          JfrReader.StackTrace stackTrace = reader.stackTraces.get(sample.stackTraceId);          int len = stackTrace.methods.length; @@ -72,7 +72,7 @@ public class ProfileSegment {              stack[i] = parseStackFrame(reader, stackTrace.methods[i]);          } -        return new ProfileSegment(sample.tid, threadName, stack, duration); +        return new ProfileSegment(sample.tid, threadName, stack, value);      }      private static AsyncStackTraceElement parseStackFrame(JfrReader reader, long methodId) { diff --git a/spark-common/src/main/java/me/lucko/spark/common/sampler/async/SampleCollector.java b/spark-common/src/main/java/me/lucko/spark/common/sampler/async/SampleCollector.java new file mode 100644 index 0000000..6054b91 --- /dev/null +++ b/spark-common/src/main/java/me/lucko/spark/common/sampler/async/SampleCollector.java @@ -0,0 +1,154 @@ +/* + * This file is part of spark. + * + *  Copyright (c) lucko (Luck) <luck@lucko.me> + *  Copyright (c) contributors + * + *  This program is free software: you can redistribute it and/or modify + *  it under the terms of the GNU General Public License as published by + *  the Free Software Foundation, either version 3 of the License, or + *  (at your option) any later version. + * + *  This program is distributed in the hope that it will be useful, + *  but WITHOUT ANY WARRANTY; without even the implied warranty of + *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the + *  GNU General Public License for more details. + * + *  You should have received a copy of the GNU General Public License + *  along with this program.  If not, see <http://www.gnu.org/licenses/>. + */ + +package me.lucko.spark.common.sampler.async; + +import com.google.common.collect.ImmutableList; + +import me.lucko.spark.common.sampler.SamplerMode; +import me.lucko.spark.common.sampler.async.AsyncProfilerAccess.ProfilingEvent; +import me.lucko.spark.common.sampler.async.jfr.JfrReader.AllocationSample; +import me.lucko.spark.common.sampler.async.jfr.JfrReader.Event; +import me.lucko.spark.common.sampler.async.jfr.JfrReader.ExecutionSample; + +import java.util.Collection; +import java.util.Objects; + +/** + * Collects and processes sample events for a given type. + * + * @param <E> the event type + */ +public interface SampleCollector<E extends Event> { + +    /** +     * Gets the arguments to initialise the profiler. +     * +     * @param access the async profiler access object +     * @return the initialisation arguments +     */ +    Collection<String> initArguments(AsyncProfilerAccess access); + +    /** +     * Gets the event class processed by this sample collector. +     * +     * @return the event class +     */ +    Class<E> eventClass(); + +    /** +     * Gets the measurements for a given event +     * +     * @param event the event +     * @return the measurement +     */ +    long measure(E event); + +    /** +     * Gets the mode for the collector. +     * +     * @return the mode +     */ +    SamplerMode getMode(); + +    /** +     * Sample collector for execution (cpu time) profiles. +     */ +    final class Execution implements SampleCollector<ExecutionSample> { +        private final int interval; // time in microseconds + +        public Execution(int interval) { +            this.interval = interval; +        } + +        @Override +        public Collection<String> initArguments(AsyncProfilerAccess access) { +            ProfilingEvent event = access.getProfilingEvent(); +            Objects.requireNonNull(event, "event"); + +            return ImmutableList.of( +                    "event=" + event, +                    "interval=" + this.interval + "us" +            ); +        } + +        @Override +        public Class<ExecutionSample> eventClass() { +            return ExecutionSample.class; +        } + +        @Override +        public long measure(ExecutionSample event) { +            return event.value() * this.interval; +        } + +        @Override +        public SamplerMode getMode() { +            return SamplerMode.EXECUTION; +        } +    } + +    /** +     * Sample collector for allocation (memory) profiles. +     */ +    final class Allocation implements SampleCollector<AllocationSample> { +        private final int intervalBytes; +        private final boolean liveOnly; + +        public Allocation(int intervalBytes, boolean liveOnly) { +            this.intervalBytes = intervalBytes; +            this.liveOnly = liveOnly; +        } + +        public boolean isLiveOnly() { +            return this.liveOnly; +        } + +        @Override +        public Collection<String> initArguments(AsyncProfilerAccess access) { +            ProfilingEvent event = access.getAllocationProfilingEvent(); +            Objects.requireNonNull(event, "event"); + +            ImmutableList.Builder<String> builder = ImmutableList.builder(); +            builder.add("event=" + event); +            builder.add("alloc=" + this.intervalBytes); +            if (this.liveOnly) { +                builder.add("live"); +            } +            return builder.build(); +        } + +        @Override +        public Class<AllocationSample> eventClass() { +            return AllocationSample.class; +        } + +        @Override +        public long measure(AllocationSample event) { +            return event.value(); +        } + +        @Override +        public SamplerMode getMode() { +            return SamplerMode.ALLOCATION; +        } +    } + +} diff --git a/spark-common/src/main/java/me/lucko/spark/common/sampler/java/JavaSampler.java b/spark-common/src/main/java/me/lucko/spark/common/sampler/java/JavaSampler.java index 72a37e8..e29619b 100644 --- a/spark-common/src/main/java/me/lucko/spark/common/sampler/java/JavaSampler.java +++ b/spark-common/src/main/java/me/lucko/spark/common/sampler/java/JavaSampler.java @@ -23,14 +23,14 @@ package me.lucko.spark.common.sampler.java;  import com.google.common.util.concurrent.ThreadFactoryBuilder;  import me.lucko.spark.common.SparkPlatform; -import me.lucko.spark.common.command.sender.CommandSender;  import me.lucko.spark.common.sampler.AbstractSampler; +import me.lucko.spark.common.sampler.SamplerMode;  import me.lucko.spark.common.sampler.SamplerSettings; -import me.lucko.spark.common.sampler.node.MergeMode; -import me.lucko.spark.common.sampler.source.ClassSourceLookup;  import me.lucko.spark.common.sampler.window.ProfilingWindowUtils;  import me.lucko.spark.common.sampler.window.WindowStatisticsCollector;  import me.lucko.spark.common.tick.TickHook; +import me.lucko.spark.common.util.SparkThreadFactory; +import me.lucko.spark.common.ws.ViewerSocket;  import me.lucko.spark.proto.SparkSamplerProtos.SamplerData;  import java.lang.management.ManagementFactory; @@ -51,12 +51,18 @@ public class JavaSampler extends AbstractSampler implements Runnable {      /** The worker pool for inserting stack nodes */      private final ScheduledExecutorService workerPool = Executors.newScheduledThreadPool( -            6, new ThreadFactoryBuilder().setNameFormat("spark-worker-" + THREAD_ID.getAndIncrement() + "-%d").build() +            6, new ThreadFactoryBuilder() +                    .setNameFormat("spark-java-sampler-" + THREAD_ID.getAndIncrement() + "-%d") +                    .setUncaughtExceptionHandler(SparkThreadFactory.EXCEPTION_HANDLER) +                    .build()      );      /** The main sampling task */      private ScheduledFuture<?> task; +    /** The task to send statistics to the viewer socket */ +    private ScheduledFuture<?> socketStatisticsTask; +      /** The thread management interface for the current JVM */      private final ThreadMXBean threadBean = ManagementFactory.getThreadMXBean(); @@ -90,6 +96,7 @@ public class JavaSampler extends AbstractSampler implements Runnable {              }          } +        this.windowStatisticsCollector.recordWindowStartTime(ProfilingWindowUtils.unixMillisToWindow(this.startTime));          this.task = this.workerPool.scheduleAtFixedRate(this, 0, this.interval, TimeUnit.MICROSECONDS);      } @@ -99,10 +106,16 @@ public class JavaSampler extends AbstractSampler implements Runnable {          this.task.cancel(false); +        if (this.socketStatisticsTask != null) { +            this.socketStatisticsTask.cancel(false); +        } +          if (!cancelled) {              // collect statistics for the final window              this.windowStatisticsCollector.measureNow(this.lastWindow.get());          } + +        this.workerPool.shutdown();      }      @Override @@ -127,6 +140,15 @@ public class JavaSampler extends AbstractSampler implements Runnable {          }      } +    @Override +    public void attachSocket(ViewerSocket socket) { +        super.attachSocket(socket); + +        if (this.socketStatisticsTask == null) { +            this.socketStatisticsTask = this.workerPool.scheduleAtFixedRate(this::sendStatisticsToSocket, 10, 10, TimeUnit.SECONDS); +        } +    } +      private final class InsertDataTask implements Runnable {          private final ThreadInfo[] threadDumps;          private final int window; @@ -149,6 +171,9 @@ public class JavaSampler extends AbstractSampler implements Runnable {              int previousWindow = JavaSampler.this.lastWindow.getAndUpdate(previous -> Math.max(this.window, previous));              if (previousWindow != 0 && previousWindow != this.window) { +                // record the start time for the new window +                JavaSampler.this.windowStatisticsCollector.recordWindowStartTime(this.window); +                  // collect statistics for the previous window                  JavaSampler.this.windowStatisticsCollector.measureNow(previousWindow); @@ -156,16 +181,25 @@ public class JavaSampler extends AbstractSampler implements Runnable {                  IntPredicate predicate = ProfilingWindowUtils.keepHistoryBefore(this.window);                  JavaSampler.this.dataAggregator.pruneData(predicate);                  JavaSampler.this.windowStatisticsCollector.pruneStatistics(predicate); + +                JavaSampler.this.workerPool.execute(JavaSampler.this::processWindowRotate);              }          }      }      @Override -    public SamplerData toProto(SparkPlatform platform, CommandSender creator, String comment, MergeMode mergeMode, ClassSourceLookup classSourceLookup) { +    public SamplerData toProto(SparkPlatform platform, ExportProps exportProps) {          SamplerData.Builder proto = SamplerData.newBuilder(); -        writeMetadataToProto(proto, platform, creator, comment, this.dataAggregator); -        writeDataToProto(proto, this.dataAggregator, mergeMode, classSourceLookup); +        if (exportProps.channelInfo() != null) { +            proto.setChannelInfo(exportProps.channelInfo()); +        } +        writeMetadataToProto(proto, platform, exportProps.creator(), exportProps.comment(), this.dataAggregator); +        writeDataToProto(proto, this.dataAggregator, exportProps.mergeMode().get(), exportProps.classSourceLookup().get());          return proto.build();      } +    @Override +    public SamplerMode getMode() { +        return SamplerMode.EXECUTION; +    }  } diff --git a/spark-common/src/main/java/me/lucko/spark/common/sampler/java/TickedDataAggregator.java b/spark-common/src/main/java/me/lucko/spark/common/sampler/java/TickedDataAggregator.java index d537b96..08cb719 100644 --- a/spark-common/src/main/java/me/lucko/spark/common/sampler/java/TickedDataAggregator.java +++ b/spark-common/src/main/java/me/lucko/spark/common/sampler/java/TickedDataAggregator.java @@ -30,6 +30,7 @@ import me.lucko.spark.proto.SparkSamplerProtos.SamplerMetadata;  import java.lang.management.ThreadInfo;  import java.util.ArrayList;  import java.util.List; +import java.util.concurrent.Executor;  import java.util.concurrent.ExecutorService;  import java.util.concurrent.TimeUnit; @@ -75,7 +76,7 @@ public class TickedDataAggregator extends JavaDataAggregator {      public SamplerMetadata.DataAggregator getMetadata() {          // push the current tick (so numberOfTicks is accurate)          synchronized (this.mutex) { -            pushCurrentTick(); +            pushCurrentTick(Runnable::run);              this.currentData = null;          } @@ -92,7 +93,7 @@ public class TickedDataAggregator extends JavaDataAggregator {          synchronized (this.mutex) {              int tick = this.tickHook.getCurrentTick();              if (this.currentTick != tick || this.currentData == null) { -                pushCurrentTick(); +                pushCurrentTick(this.workerPool);                  this.currentTick = tick;                  this.currentData = new TickList(this.expectedSize, window);              } @@ -102,7 +103,7 @@ public class TickedDataAggregator extends JavaDataAggregator {      }      // guarded by 'mutex' -    private void pushCurrentTick() { +    private void pushCurrentTick(Executor executor) {          TickList currentData = this.currentData;          if (currentData == null) {              return; @@ -116,7 +117,7 @@ public class TickedDataAggregator extends JavaDataAggregator {              return;          } -        this.workerPool.submit(currentData); +        executor.execute(currentData);          this.tickCounter.increment();      } @@ -124,7 +125,7 @@ public class TickedDataAggregator extends JavaDataAggregator {      public List<ThreadNode> exportData() {          // push the current tick          synchronized (this.mutex) { -            pushCurrentTick(); +            pushCurrentTick(Runnable::run);          }          return super.exportData(); diff --git a/spark-common/src/main/java/me/lucko/spark/common/sampler/window/ProtoTimeEncoder.java b/spark-common/src/main/java/me/lucko/spark/common/sampler/window/ProtoTimeEncoder.java index 03da075..fb4a4fc 100644 --- a/spark-common/src/main/java/me/lucko/spark/common/sampler/window/ProtoTimeEncoder.java +++ b/spark-common/src/main/java/me/lucko/spark/common/sampler/window/ProtoTimeEncoder.java @@ -27,18 +27,25 @@ import java.util.HashMap;  import java.util.List;  import java.util.Map;  import java.util.concurrent.atomic.LongAdder; +import java.util.function.LongToDoubleFunction;  import java.util.stream.IntStream;  /**   * Encodes a map of int->double into a double array.   */  public class ProtoTimeEncoder { + +    /** A transformer function to transform the 'time' value from a long to a double */ +    private final LongToDoubleFunction valueTransformer; +      /** A sorted array of all possible keys to encode */      private final int[] keys;      /** A map of key value -> index in the keys array */      private final Map<Integer, Integer> keysToIndex; -    public ProtoTimeEncoder(List<ThreadNode> sourceData) { +    public ProtoTimeEncoder(LongToDoubleFunction valueTransformer, List<ThreadNode> sourceData) { +        this.valueTransformer = valueTransformer; +          // get an array of all keys that show up in the source data          this.keys = sourceData.stream()                  .map(n -> n.getTimeWindows().stream().mapToInt(i -> i)) @@ -81,11 +88,8 @@ public class ProtoTimeEncoder {                  throw new RuntimeException("No index for key " + key + " in " + this.keysToIndex.keySet());              } -            // convert the duration from microseconds -> milliseconds -            double durationInMilliseconds = value.longValue() / 1000d; -              // store in the array -            array[idx] = durationInMilliseconds; +            array[idx] = this.valueTransformer.applyAsDouble(value.longValue());          });          return array; diff --git a/spark-common/src/main/java/me/lucko/spark/common/sampler/window/WindowStatisticsCollector.java b/spark-common/src/main/java/me/lucko/spark/common/sampler/window/WindowStatisticsCollector.java index ce65013..86c0b20 100644 --- a/spark-common/src/main/java/me/lucko/spark/common/sampler/window/WindowStatisticsCollector.java +++ b/spark-common/src/main/java/me/lucko/spark/common/sampler/window/WindowStatisticsCollector.java @@ -20,29 +20,35 @@  package me.lucko.spark.common.sampler.window; +import me.lucko.spark.api.statistic.misc.DoubleAverageInfo;  import me.lucko.spark.common.SparkPlatform;  import me.lucko.spark.common.monitor.cpu.CpuMonitor;  import me.lucko.spark.common.monitor.tick.TickStatistics;  import me.lucko.spark.common.platform.world.AsyncWorldInfoProvider;  import me.lucko.spark.common.platform.world.WorldInfoProvider;  import me.lucko.spark.common.tick.TickHook; -import me.lucko.spark.common.util.RollingAverage;  import me.lucko.spark.proto.SparkProtos; +import java.util.HashMap;  import java.util.Map;  import java.util.concurrent.ConcurrentHashMap;  import java.util.concurrent.atomic.AtomicInteger;  import java.util.function.IntPredicate; +import java.util.logging.Level;  /**   * Collects statistics for each profiling window.   */  public class WindowStatisticsCollector { -    private static final SparkProtos.WindowStatistics ZERO = SparkProtos.WindowStatistics.newBuilder().build(); +    private static final SparkProtos.WindowStatistics ZERO = SparkProtos.WindowStatistics.newBuilder() +            .setDuration(ProfilingWindowUtils.WINDOW_SIZE_SECONDS * 1000) +            .build();      /** The platform */      private final SparkPlatform platform; +    /** Map of profiling window -> start time */ +    private final Map<Integer, Long> windowStartTimes = new HashMap<>();      /** Map of profiling window -> statistics */      private final Map<Integer, SparkProtos.WindowStatistics> stats; @@ -100,12 +106,21 @@ public class WindowStatisticsCollector {      }      /** +     * Records the wall-clock time when a window was started. +     * +     * @param window the window +     */ +    public void recordWindowStartTime(int window) { +        this.windowStartTimes.put(window, System.currentTimeMillis()); +    } + +    /**       * Measures statistics for the given window if none have been recorded yet.       *       * @param window the window       */      public void measureNow(int window) { -        this.stats.computeIfAbsent(window, w -> measure()); +        this.stats.computeIfAbsent(window, this::measure);      }      /** @@ -132,14 +147,25 @@ public class WindowStatisticsCollector {       *       * @return the current statistics       */ -    private SparkProtos.WindowStatistics measure() { +    private SparkProtos.WindowStatistics measure(int window) {          SparkProtos.WindowStatistics.Builder builder = SparkProtos.WindowStatistics.newBuilder(); +        long endTime = System.currentTimeMillis(); +        Long startTime = this.windowStartTimes.get(window); +        if (startTime == null) { +            this.platform.getPlugin().log(Level.WARNING, "Unknown start time for window " + window); +            startTime = endTime - (ProfilingWindowUtils.WINDOW_SIZE_SECONDS * 1000); // guess +        } + +        builder.setStartTime(startTime); +        builder.setEndTime(endTime); +        builder.setDuration((int) (endTime - startTime)); +          TickStatistics tickStatistics = this.platform.getTickStatistics();          if (tickStatistics != null) {              builder.setTps(tickStatistics.tps1Min()); -            RollingAverage mspt = tickStatistics.duration1Min(); +            DoubleAverageInfo mspt = tickStatistics.duration1Min();              if (mspt != null) {                  builder.setMsptMedian(mspt.median());                  builder.setMsptMax(mspt.max()); @@ -225,11 +251,13 @@ public class WindowStatisticsCollector {              if (this.startTick == -1) {                  throw new IllegalStateException("start tick not recorded");              } -            if (this.stopTick == -1) { -                throw new IllegalStateException("stop tick not recorded"); + +            int stopTick = this.stopTick; +            if (stopTick == -1) { +                stopTick = this.tickHook.getCurrentTick();              } -            return this.stopTick - this.startTick; +            return stopTick - this.startTick;          }      } diff --git a/spark-common/src/main/java/me/lucko/spark/common/util/BytebinClient.java b/spark-common/src/main/java/me/lucko/spark/common/util/BytebinClient.java index e69b94e..b8a2053 100644 --- a/spark-common/src/main/java/me/lucko/spark/common/util/BytebinClient.java +++ b/spark-common/src/main/java/me/lucko/spark/common/util/BytebinClient.java @@ -32,6 +32,8 @@ import java.util.zip.GZIPOutputStream;  /**   * Utility for posting content to bytebin. + * + * @see <a href="https://github.com/lucko/bytebin">https://github.com/lucko/bytebin</a>   */  public class BytebinClient { @@ -45,7 +47,11 @@ public class BytebinClient {          this.userAgent = userAgent;      } -    private Content postContent(String contentType, Consumer<OutputStream> consumer) throws IOException { +    private Content postContent(String contentType, Consumer<OutputStream> consumer, String userAgentExtra) throws IOException { +        String userAgent = userAgentExtra != null +                ? this.userAgent + "/" + userAgentExtra +                : this.userAgent; +          URL url = new URL(this.url + "post");          HttpURLConnection connection = (HttpURLConnection) url.openConnection();          try { @@ -55,7 +61,7 @@ public class BytebinClient {              connection.setDoOutput(true);              connection.setRequestMethod("POST");              connection.setRequestProperty("Content-Type", contentType); -            connection.setRequestProperty("User-Agent", this.userAgent); +            connection.setRequestProperty("User-Agent", userAgent);              connection.setRequestProperty("Content-Encoding", "gzip");              connection.connect(); @@ -74,14 +80,18 @@ public class BytebinClient {          }      } -    public Content postContent(AbstractMessageLite<?, ?> proto, String contentType) throws IOException { +    public Content postContent(AbstractMessageLite<?, ?> proto, String contentType, String userAgentExtra) throws IOException {          return postContent(contentType, outputStream -> {              try (OutputStream out = new GZIPOutputStream(outputStream)) {                  proto.writeTo(out);              } catch (IOException e) {                  throw new RuntimeException(e);              } -        }); +        }, userAgentExtra); +    } + +    public Content postContent(AbstractMessageLite<?, ?> proto, String contentType) throws IOException { +        return postContent(proto, contentType, null);      }      public static final class Content { diff --git a/spark-common/src/main/java/me/lucko/spark/common/util/ClassFinder.java b/spark-common/src/main/java/me/lucko/spark/common/util/ClassFinder.java index 4481786..f132613 100644 --- a/spark-common/src/main/java/me/lucko/spark/common/util/ClassFinder.java +++ b/spark-common/src/main/java/me/lucko/spark/common/util/ClassFinder.java @@ -44,6 +44,9 @@ public class ClassFinder {          } catch (Exception e) {              return;          } +        if (instrumentation == null) { +            return; +        }          // obtain and cache loaded classes          for (Class<?> loadedClass : instrumentation.getAllLoadedClasses()) { diff --git a/spark-common/src/main/java/me/lucko/spark/common/util/Configuration.java b/spark-common/src/main/java/me/lucko/spark/common/util/Configuration.java index 32f3bc6..d19ba64 100644 --- a/spark-common/src/main/java/me/lucko/spark/common/util/Configuration.java +++ b/spark-common/src/main/java/me/lucko/spark/common/util/Configuration.java @@ -22,6 +22,7 @@ package me.lucko.spark.common.util;  import com.google.gson.Gson;  import com.google.gson.GsonBuilder; +import com.google.gson.JsonArray;  import com.google.gson.JsonElement;  import com.google.gson.JsonObject;  import com.google.gson.JsonPrimitive; @@ -32,6 +33,9 @@ import java.io.IOException;  import java.nio.charset.StandardCharsets;  import java.nio.file.Files;  import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List;  public final class Configuration {      private static final Gson GSON = new GsonBuilder().setPrettyPrinting().create(); @@ -103,6 +107,21 @@ public final class Configuration {          return val.isBoolean() ? val.getAsInt() : def;      } +    public List<String> getStringList(String path) { +        JsonElement el = this.root.get(path); +        if (el == null || !el.isJsonArray()) { +            return Collections.emptyList(); +        } + +        List<String> list = new ArrayList<>(); +        for (JsonElement child : el.getAsJsonArray()) { +            if (child.isJsonPrimitive()) { +                list.add(child.getAsJsonPrimitive().getAsString()); +            } +        } +        return list; +    } +      public void setString(String path, String value) {          this.root.add(path, new JsonPrimitive(value));      } @@ -115,6 +134,14 @@ public final class Configuration {          this.root.add(path, new JsonPrimitive(value));      } +    public void setStringList(String path, List<String> value) { +        JsonArray array = new JsonArray(); +        for (String str : value) { +            array.add(str); +        } +        this.root.add(path, array); +    } +      public boolean contains(String path) {          return this.root.has(path);      } diff --git a/spark-common/src/main/java/me/lucko/spark/common/util/MediaTypes.java b/spark-common/src/main/java/me/lucko/spark/common/util/MediaTypes.java new file mode 100644 index 0000000..2c49540 --- /dev/null +++ b/spark-common/src/main/java/me/lucko/spark/common/util/MediaTypes.java @@ -0,0 +1,29 @@ +/* + * This file is part of spark. + * + *  Copyright (c) lucko (Luck) <luck@lucko.me> + *  Copyright (c) contributors + * + *  This program is free software: you can redistribute it and/or modify + *  it under the terms of the GNU General Public License as published by + *  the Free Software Foundation, either version 3 of the License, or + *  (at your option) any later version. + * + *  This program is distributed in the hope that it will be useful, + *  but WITHOUT ANY WARRANTY; without even the implied warranty of + *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the + *  GNU General Public License for more details. + * + *  You should have received a copy of the GNU General Public License + *  along with this program.  If not, see <http://www.gnu.org/licenses/>. + */ + +package me.lucko.spark.common.util; + +public enum MediaTypes { +    ; + +    public static final String SPARK_SAMPLER_MEDIA_TYPE = "application/x-spark-sampler"; +    public static final String SPARK_HEAP_MEDIA_TYPE = "application/x-spark-heap"; + +} diff --git a/spark-common/src/main/java/me/lucko/spark/common/util/RollingAverage.java b/spark-common/src/main/java/me/lucko/spark/common/util/RollingAverage.java index 65753bc..57dfdff 100644 --- a/spark-common/src/main/java/me/lucko/spark/common/util/RollingAverage.java +++ b/spark-common/src/main/java/me/lucko/spark/common/util/RollingAverage.java @@ -21,7 +21,6 @@  package me.lucko.spark.common.util;  import me.lucko.spark.api.statistic.misc.DoubleAverageInfo; -import me.lucko.spark.proto.SparkProtos;  import java.math.BigDecimal;  import java.math.RoundingMode; @@ -112,14 +111,4 @@ public class RollingAverage implements DoubleAverageInfo {          return sortedSamples[rank].doubleValue();      } -    public SparkProtos.RollingAverageValues toProto() { -        return SparkProtos.RollingAverageValues.newBuilder() -                .setMean(mean()) -                .setMax(max()) -                .setMin(min()) -                .setMedian(median()) -                .setPercentile95(percentile95th()) -                .build(); -    } -  } diff --git a/spark-common/src/main/java/me/lucko/spark/common/util/SparkThreadFactory.java b/spark-common/src/main/java/me/lucko/spark/common/util/SparkThreadFactory.java index 156fa0d..42dca12 100644 --- a/spark-common/src/main/java/me/lucko/spark/common/util/SparkThreadFactory.java +++ b/spark-common/src/main/java/me/lucko/spark/common/util/SparkThreadFactory.java @@ -23,7 +23,13 @@ package me.lucko.spark.common.util;  import java.util.concurrent.ThreadFactory;  import java.util.concurrent.atomic.AtomicInteger; -public class SparkThreadFactory implements ThreadFactory, Thread.UncaughtExceptionHandler { +public class SparkThreadFactory implements ThreadFactory { + +    public static final Thread.UncaughtExceptionHandler EXCEPTION_HANDLER = (t, e) -> { +        System.err.println("Uncaught exception thrown by thread " + t.getName()); +        e.printStackTrace(); +    }; +      private static final AtomicInteger poolNumber = new AtomicInteger(1);      private final AtomicInteger threadNumber = new AtomicInteger(1);      private final String namePrefix; @@ -36,14 +42,9 @@ public class SparkThreadFactory implements ThreadFactory, Thread.UncaughtExcepti      public Thread newThread(Runnable r) {          Thread t = new Thread(r, this.namePrefix + this.threadNumber.getAndIncrement()); -        t.setUncaughtExceptionHandler(this); +        t.setUncaughtExceptionHandler(EXCEPTION_HANDLER);          t.setDaemon(true);          return t;      } -    @Override -    public void uncaughtException(Thread t, Throwable e) { -        System.err.println("Uncaught exception thrown by thread " + t.getName()); -        e.printStackTrace(); -    }  } diff --git a/spark-common/src/main/java/me/lucko/spark/common/util/StatisticFormatter.java b/spark-common/src/main/java/me/lucko/spark/common/util/StatisticFormatter.java index 22ee9bb..b488f50 100644 --- a/spark-common/src/main/java/me/lucko/spark/common/util/StatisticFormatter.java +++ b/spark-common/src/main/java/me/lucko/spark/common/util/StatisticFormatter.java @@ -22,6 +22,8 @@ package me.lucko.spark.common.util;  import com.google.common.base.Strings; +import me.lucko.spark.api.statistic.misc.DoubleAverageInfo; +  import net.kyori.adventure.text.Component;  import net.kyori.adventure.text.TextComponent;  import net.kyori.adventure.text.format.TextColor; @@ -55,7 +57,7 @@ public enum StatisticFormatter {          return text((tps > 20.0 ? "*" : "") + Math.min(Math.round(tps * 100.0) / 100.0, 20.0), color);      } -    public static TextComponent formatTickDurations(RollingAverage average) { +    public static TextComponent formatTickDurations(DoubleAverageInfo average) {          return text()                  .append(formatTickDuration(average.min()))                  .append(text('/', GRAY)) diff --git a/spark-common/src/main/java/me/lucko/spark/common/ws/CryptoAlgorithm.java b/spark-common/src/main/java/me/lucko/spark/common/ws/CryptoAlgorithm.java new file mode 100644 index 0000000..f6cf1db --- /dev/null +++ b/spark-common/src/main/java/me/lucko/spark/common/ws/CryptoAlgorithm.java @@ -0,0 +1,90 @@ +/* + * This file is part of spark. + * + *  Copyright (c) lucko (Luck) <luck@lucko.me> + *  Copyright (c) contributors + * + *  This program is free software: you can redistribute it and/or modify + *  it under the terms of the GNU General Public License as published by + *  the Free Software Foundation, either version 3 of the License, or + *  (at your option) any later version. + * + *  This program is distributed in the hope that it will be useful, + *  but WITHOUT ANY WARRANTY; without even the implied warranty of + *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the + *  GNU General Public License for more details. + * + *  You should have received a copy of the GNU General Public License + *  along with this program.  If not, see <http://www.gnu.org/licenses/>. + */ + +package me.lucko.spark.common.ws; + +import com.google.protobuf.ByteString; + +import java.security.KeyFactory; +import java.security.KeyPair; +import java.security.KeyPairGenerator; +import java.security.NoSuchAlgorithmException; +import java.security.PublicKey; +import java.security.Signature; +import java.security.spec.X509EncodedKeySpec; + +/** + * An algorithm for keypair/signature cryptography. + */ +public enum CryptoAlgorithm { + +    Ed25519("Ed25519", 255, "Ed25519"), +    RSA2048("RSA", 2048, "SHA256withRSA"); + +    private final String keyAlgorithm; +    private final int keySize; +    private final String signatureAlgorithm; + +    CryptoAlgorithm(String keyAlgorithm, int keySize, String signatureAlgorithm) { +        this.keyAlgorithm = keyAlgorithm; +        this.keySize = keySize; +        this.signatureAlgorithm = signatureAlgorithm; +    } + +    public KeyPairGenerator createKeyPairGenerator() throws NoSuchAlgorithmException { +        return KeyPairGenerator.getInstance(this.keyAlgorithm); +    } + +    public KeyFactory createKeyFactory() throws NoSuchAlgorithmException { +        return KeyFactory.getInstance(this.keyAlgorithm); +    } + +    public Signature createSignature() throws NoSuchAlgorithmException { +        return Signature.getInstance(this.signatureAlgorithm); +    } + +    public KeyPair generateKeyPair() { +        try { +            KeyPairGenerator generator = createKeyPairGenerator(); +            generator.initialize(this.keySize); +            return generator.generateKeyPair(); +        } catch (Exception e) { +            throw new RuntimeException("Exception generating keypair", e); +        } +    } + +    public PublicKey decodePublicKey(byte[] bytes) throws IllegalArgumentException { +        try { +            X509EncodedKeySpec spec = new X509EncodedKeySpec(bytes); +            KeyFactory factory = createKeyFactory(); +            return factory.generatePublic(spec); +        } catch (Exception e) { +            throw new IllegalArgumentException("Exception parsing public key", e); +        } +    } + +    public PublicKey decodePublicKey(ByteString bytes) throws IllegalArgumentException { +        if (bytes == null) { +            return null; +        } +        return decodePublicKey(bytes.toByteArray()); +    } + +} diff --git a/spark-common/src/main/java/me/lucko/spark/common/ws/TrustedKeyStore.java b/spark-common/src/main/java/me/lucko/spark/common/ws/TrustedKeyStore.java new file mode 100644 index 0000000..1605a38 --- /dev/null +++ b/spark-common/src/main/java/me/lucko/spark/common/ws/TrustedKeyStore.java @@ -0,0 +1,139 @@ +/* + * This file is part of spark. + * + *  Copyright (c) lucko (Luck) <luck@lucko.me> + *  Copyright (c) contributors + * + *  This program is free software: you can redistribute it and/or modify + *  it under the terms of the GNU General Public License as published by + *  the Free Software Foundation, either version 3 of the License, or + *  (at your option) any later version. + * + *  This program is distributed in the hope that it will be useful, + *  but WITHOUT ANY WARRANTY; without even the implied warranty of + *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the + *  GNU General Public License for more details. + * + *  You should have received a copy of the GNU General Public License + *  along with this program.  If not, see <http://www.gnu.org/licenses/>. + */ + +package me.lucko.spark.common.ws; + +import me.lucko.spark.common.util.Configuration; + +import java.security.KeyPair; +import java.security.PrivateKey; +import java.security.PublicKey; +import java.util.Base64; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Collectors; + +/** + * A store of trusted public keys. + */ +public class TrustedKeyStore { +    private static final String TRUSTED_KEYS_OPTION = "trustedKeys"; + +    /** The spark configuration */ +    private final Configuration configuration; +    /** Gets the local public/private key */ +    private final CompletableFuture<KeyPair> localKeyPair; +    /** A set of remote public keys to trust */ +    private final Set<PublicKey> remoteTrustedKeys; +    /** A mpa of pending remote public keys */ +    private final Map<String, PublicKey> remotePendingKeys = new HashMap<>(); + +    public TrustedKeyStore(Configuration configuration) { +        this.configuration = configuration; +        this.localKeyPair = CompletableFuture.supplyAsync(ViewerSocketConnection.CRYPTO::generateKeyPair); +        this.remoteTrustedKeys = new HashSet<>(); +        readTrustedKeys(); +    } + +    /** +     * Gets the local public key. +     * +     * @return the local public key +     */ +    public PublicKey getLocalPublicKey() { +        return this.localKeyPair.join().getPublic(); +    } + +    /** +     * Gets the local private key. +     * +     * @return the local private key +     */ +    public PrivateKey getLocalPrivateKey() { +        return this.localKeyPair.join().getPrivate(); +    } + +    /** +     * Checks if a remote public key is trusted +     * +     * @param publicKey the public key +     * @return if the key is trusted +     */ +    public boolean isKeyTrusted(PublicKey publicKey) { +        return publicKey != null && this.remoteTrustedKeys.contains(publicKey); +    } + +    /** +     * Adds a pending public key to be trusted in the future. +     * +     * @param clientId the client id submitting the key +     * @param publicKey the public key +     */ +    public void addPendingKey(String clientId, PublicKey publicKey) { +        this.remotePendingKeys.put(clientId, publicKey); +    } + +    /** +     * Trusts a previously submitted remote public key +     * +     * @param clientId the id of the client that submitted the key +     * @return true if the key was found and trusted +     */ +    public boolean trustPendingKey(String clientId) { +        PublicKey key = this.remotePendingKeys.remove(clientId); +        if (key == null) { +            return false; +        } + +        this.remoteTrustedKeys.add(key); +        writeTrustedKeys(); +        return true; +    } + +    /** +     * Reads trusted keys from the configuration +     */ +    private void readTrustedKeys() { +        for (String encodedKey : this.configuration.getStringList(TRUSTED_KEYS_OPTION)) { +            try { +                PublicKey publicKey = ViewerSocketConnection.CRYPTO.decodePublicKey(Base64.getDecoder().decode(encodedKey)); +                this.remoteTrustedKeys.add(publicKey); +            } catch (Exception e) { +                e.printStackTrace(); +            } +        } +    } + +    /** +     * Writes trusted keys to the configuration +     */ +    private void writeTrustedKeys() { +        List<String> encodedKeys = this.remoteTrustedKeys.stream() +                .map(key -> Base64.getEncoder().encodeToString(key.getEncoded())) +                .collect(Collectors.toList()); + +        this.configuration.setStringList(TRUSTED_KEYS_OPTION, encodedKeys); +    } + +} diff --git a/spark-common/src/main/java/me/lucko/spark/common/ws/ViewerSocket.java b/spark-common/src/main/java/me/lucko/spark/common/ws/ViewerSocket.java new file mode 100644 index 0000000..6a9c2b7 --- /dev/null +++ b/spark-common/src/main/java/me/lucko/spark/common/ws/ViewerSocket.java @@ -0,0 +1,255 @@ +/* + * This file is part of spark. + * + *  Copyright (c) lucko (Luck) <luck@lucko.me> + *  Copyright (c) contributors + * + *  This program is free software: you can redistribute it and/or modify + *  it under the terms of the GNU General Public License as published by + *  the Free Software Foundation, either version 3 of the License, or + *  (at your option) any later version. + * + *  This program is distributed in the hope that it will be useful, + *  but WITHOUT ANY WARRANTY; without even the implied warranty of + *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the + *  GNU General Public License for more details. + * + *  You should have received a copy of the GNU General Public License + *  along with this program.  If not, see <http://www.gnu.org/licenses/>. + */ + +package me.lucko.spark.common.ws; + +import com.google.protobuf.ByteString; + +import me.lucko.bytesocks.client.BytesocksClient; +import me.lucko.spark.common.SparkPlatform; +import me.lucko.spark.common.sampler.AbstractSampler; +import me.lucko.spark.common.sampler.Sampler; +import me.lucko.spark.common.sampler.window.ProfilingWindowUtils; +import me.lucko.spark.common.util.MediaTypes; +import me.lucko.spark.proto.SparkProtos; +import me.lucko.spark.proto.SparkSamplerProtos; +import me.lucko.spark.proto.SparkWebSocketProtos.ClientConnect; +import me.lucko.spark.proto.SparkWebSocketProtos.ClientPing; +import me.lucko.spark.proto.SparkWebSocketProtos.PacketWrapper; +import me.lucko.spark.proto.SparkWebSocketProtos.ServerConnectResponse; +import me.lucko.spark.proto.SparkWebSocketProtos.ServerPong; +import me.lucko.spark.proto.SparkWebSocketProtos.ServerUpdateSamplerData; +import me.lucko.spark.proto.SparkWebSocketProtos.ServerUpdateStatistics; + +import java.security.PublicKey; +import java.util.concurrent.TimeUnit; +import java.util.logging.Level; + +/** + * Represents a connection with the spark viewer. + */ +public class ViewerSocket implements ViewerSocketConnection.Listener, AutoCloseable { + +    /** Allow 60 seconds for the first client to connect */ +    private static final long SOCKET_INITIAL_TIMEOUT = TimeUnit.SECONDS.toMillis(60); + +    /** Once established, expect a ping at least once every 30 seconds */ +    private static final long SOCKET_ESTABLISHED_TIMEOUT = TimeUnit.SECONDS.toMillis(30); + +    /** The spark platform */ +    private final SparkPlatform platform; +    /** The export props to use when exporting the sampler data */ +    private final Sampler.ExportProps exportProps; +    /** The underlying connection */ +    private final ViewerSocketConnection socket; + +    private boolean closed = false; +    private final long socketOpenTime = System.currentTimeMillis(); +    private long lastPing = 0; +    private String lastPayloadId = null; + +    public ViewerSocket(SparkPlatform platform, BytesocksClient client, Sampler.ExportProps exportProps) throws Exception { +        this.platform = platform; +        this.exportProps = exportProps; +        this.socket = new ViewerSocketConnection(platform, client, this); +    } + +    private void log(String message) { +        this.platform.getPlugin().log(Level.INFO, "[Viewer - " + this.socket.getChannelId() + "] " + message); +    } + +    /** +     * Gets the initial payload to send to the viewer. +     * +     * @return the payload +     */ +    public SparkSamplerProtos.SocketChannelInfo getPayload() { +        return SparkSamplerProtos.SocketChannelInfo.newBuilder() +                .setChannelId(this.socket.getChannelId()) +                .setPublicKey(ByteString.copyFrom(this.platform.getTrustedKeyStore().getLocalPublicKey().getEncoded())) +                .build(); +    } + +    public boolean isOpen() { +        return !this.closed && this.socket.isOpen(); +    } + +    /** +     * Called each time the sampler rotates to a new window. +     * +     * @param sampler the sampler +     */ +    public void processWindowRotate(AbstractSampler sampler) { +        if (this.closed) { +            return; +        } + +        long time = System.currentTimeMillis(); +        if ((time - this.socketOpenTime) > SOCKET_INITIAL_TIMEOUT && (time - this.lastPing) > SOCKET_ESTABLISHED_TIMEOUT) { +            log("No clients have pinged for 30s, closing socket"); +            close(); +            return; +        } + +        // no clients connected yet! +        if (this.lastPing == 0) { +            return; +        } + +        try { +            SparkSamplerProtos.SamplerData samplerData = sampler.toProto(this.platform, this.exportProps); +            String key = this.platform.getBytebinClient().postContent(samplerData, MediaTypes.SPARK_SAMPLER_MEDIA_TYPE, "live").key(); +            sendUpdatedSamplerData(key); +        } catch (Exception e) { +            this.platform.getPlugin().log(Level.WARNING, "Error whilst sending updated sampler data to the socket"); +            e.printStackTrace(); +        } +    } + +    /** +     * Called when the sampler stops. +     * +     * @param sampler the sampler +     */ +    public void processSamplerStopped(AbstractSampler sampler) { +        if (this.closed) { +            return; +        } + +        close(); +    } + +    @Override +    public void close() { +        this.socket.sendPacket(builder -> builder.setServerPong(ServerPong.newBuilder() +                .setOk(false) +                .build() +        )); +        this.socket.close(); +        this.closed = true; +    } + +    @Override +    public boolean isKeyTrusted(PublicKey publicKey) { +        return this.platform.getTrustedKeyStore().isKeyTrusted(publicKey); +    } + +    /** +     * Sends a message to the socket to say that the given client is now trusted. +     * +     * @param clientId the client id +     */ +    public void sendClientTrustedMessage(String clientId) { +        this.socket.sendPacket(builder -> builder.setServerConnectResponse(ServerConnectResponse.newBuilder() +                .setClientId(clientId) +                .setState(ServerConnectResponse.State.ACCEPTED) +                .build() +        )); +    } + +    /** +     * Sends a message to the socket to indicate that updated sampler data is available +     * +     * @param payloadId the payload id of the updated data +     */ +    public void sendUpdatedSamplerData(String payloadId) { +        this.socket.sendPacket(builder -> builder.setServerUpdateSampler(ServerUpdateSamplerData.newBuilder() +                .setPayloadId(payloadId) +                .build() +        )); +        this.lastPayloadId = payloadId; +    } + +    /** +     * Sends a message to the socket with updated statistics +     * +     * @param platform the platform statistics +     * @param system the system statistics +     */ +    public void sendUpdatedStatistics(SparkProtos.PlatformStatistics platform, SparkProtos.SystemStatistics system) { +        this.socket.sendPacket(builder -> builder.setServerUpdateStatistics(ServerUpdateStatistics.newBuilder() +                .setPlatform(platform) +                .setSystem(system) +                .build() +        )); +    } + +    @Override +    public void onPacket(PacketWrapper packet, boolean verified, PublicKey publicKey) throws Exception { +        switch (packet.getPacketCase()) { +            case CLIENT_PING: +                onClientPing(packet.getClientPing(), publicKey); +                break; +            case CLIENT_CONNECT: +                onClientConnect(packet.getClientConnect(), verified, publicKey); +                break; +            default: +                throw new IllegalArgumentException("Unexpected packet: " + packet.getPacketCase()); +        } +    } + +    private void onClientPing(ClientPing packet, PublicKey publicKey) { +        this.lastPing = System.currentTimeMillis(); +        this.socket.sendPacket(builder -> builder.setServerPong(ServerPong.newBuilder() +                .setOk(!this.closed) +                .setData(packet.getData()) +                .build() +        )); +    } + +    private void onClientConnect(ClientConnect packet, boolean verified, PublicKey publicKey) { +        if (publicKey == null) { +            throw new IllegalStateException("Missing public key"); +        } + +        this.lastPing = System.currentTimeMillis(); + +        String clientId = packet.getClientId(); +        log("Client connected: clientId=" + clientId + ", keyhash=" + hashPublicKey(publicKey) + ", desc=" + packet.getDescription()); + +        ServerConnectResponse.Builder resp = ServerConnectResponse.newBuilder() +                .setClientId(clientId) +                .setSettings(ServerConnectResponse.Settings.newBuilder() +                        .setSamplerInterval(ProfilingWindowUtils.WINDOW_SIZE_SECONDS) +                        .setStatisticsInterval(10) +                        .build() +                ); + +        if (this.lastPayloadId != null) { +            resp.setLastPayloadId(this.lastPayloadId); +        } + +        if (this.closed) { +            resp.setState(ServerConnectResponse.State.REJECTED); +        } else if (verified) { +            resp.setState(ServerConnectResponse.State.ACCEPTED); +        } else { +            resp.setState(ServerConnectResponse.State.UNTRUSTED); +            this.platform.getTrustedKeyStore().addPendingKey(clientId, publicKey); +        } + +        this.socket.sendPacket(builder -> builder.setServerConnectResponse(resp.build())); +    } + +    private static String hashPublicKey(PublicKey publicKey) { +        return publicKey == null ? "null" : Integer.toHexString(publicKey.hashCode()); +    } + +} diff --git a/spark-common/src/main/java/me/lucko/spark/common/ws/ViewerSocketConnection.java b/spark-common/src/main/java/me/lucko/spark/common/ws/ViewerSocketConnection.java new file mode 100644 index 0000000..9079860 --- /dev/null +++ b/spark-common/src/main/java/me/lucko/spark/common/ws/ViewerSocketConnection.java @@ -0,0 +1,218 @@ +/* + * This file is part of spark. + * + *  Copyright (c) lucko (Luck) <luck@lucko.me> + *  Copyright (c) contributors + * + *  This program is free software: you can redistribute it and/or modify + *  it under the terms of the GNU General Public License as published by + *  the Free Software Foundation, either version 3 of the License, or + *  (at your option) any later version. + * + *  This program is distributed in the hope that it will be useful, + *  but WITHOUT ANY WARRANTY; without even the implied warranty of + *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the + *  GNU General Public License for more details. + * + *  You should have received a copy of the GNU General Public License + *  along with this program.  If not, see <http://www.gnu.org/licenses/>. + */ + +package me.lucko.spark.common.ws; + +import com.google.protobuf.ByteString; + +import me.lucko.bytesocks.client.BytesocksClient; +import me.lucko.spark.common.SparkPlatform; +import me.lucko.spark.proto.SparkWebSocketProtos.PacketWrapper; +import me.lucko.spark.proto.SparkWebSocketProtos.RawPacket; + +import java.io.IOException; +import java.security.PrivateKey; +import java.security.PublicKey; +import java.security.Signature; +import java.util.Base64; +import java.util.function.Consumer; +import java.util.logging.Level; + +/** + * Controls a websocket connection between a spark server (the plugin/mod) and a spark client (the web viewer). + */ +public class ViewerSocketConnection implements BytesocksClient.Listener, AutoCloseable { + +    /** The protocol version */ +    public static final int VERSION_1 = 1; +    /** The crypto algorithm used to sign/verify messages sent between the server and client */ +    public static final CryptoAlgorithm CRYPTO = CryptoAlgorithm.RSA2048; + +    /** The platform */ +    private final SparkPlatform platform; +    /** The underlying listener */ +    private final Listener listener; +    /** The private key used to sign messages sent from this connection */ +    private final PrivateKey privateKey; +    /** The bytesocks socket */ +    private final BytesocksClient.Socket socket; + +    public ViewerSocketConnection(SparkPlatform platform, BytesocksClient client, Listener listener) throws Exception { +        this.platform = platform; +        this.listener = listener; +        this.privateKey = platform.getTrustedKeyStore().getLocalPrivateKey(); +        this.socket = client.createAndConnect(this); +    } + +    public interface Listener { + +        /** +         * Checks if the given public key is trusted +         * +         * @param publicKey the public key +         * @return true if trusted +         */ +        boolean isKeyTrusted(PublicKey publicKey); + +        /** +         * Handles a packet sent to the socket +         * +         * @param packet the packet that was sent +         * @param verified if the packet was signed by a trusted key +         * @param publicKey the public key the packet was signed with +         */ +        void onPacket(PacketWrapper packet, boolean verified, PublicKey publicKey) throws Exception; +    } + +    /** +     * Gets the bytesocks channel id +     * +     * @return the channel id +     */ +    public String getChannelId() { +        return this.socket.getChannelId(); +    } + +    /** +     * Gets if the underlying socket is open +     * +     * @return true if the socket is open +     */ +    public boolean isOpen() { +        return this.socket.isOpen(); +    } + +    @Override +    public void onText(CharSequence data) { +        try { +            RawPacket packet = decodeRawPacket(data); +            handleRawPacket(packet); +        } catch (Exception e) { +            this.platform.getPlugin().log(Level.WARNING, "Exception occurred while reading data from the socket"); +            e.printStackTrace(); +        } +    } + +    @Override +    public void onError(Throwable error) { +        this.platform.getPlugin().log(Level.INFO, "Socket error: " + error.getClass().getName() + " " + error.getMessage()); +        error.printStackTrace(); +    } + +    @Override +    public void onClose(int statusCode, String reason) { +        //this.platform.getPlugin().log(Level.INFO, "Socket closed with status " + statusCode + " and reason " + reason); +    } + +    /** +     * Sends a packet to the socket. +     * +     * @param packetBuilder the builder to construct the wrapper packet +     */ +    public void sendPacket(Consumer<PacketWrapper.Builder> packetBuilder) { +        PacketWrapper.Builder builder = PacketWrapper.newBuilder(); +        packetBuilder.accept(builder); +        PacketWrapper wrapper = builder.build(); + +        try { +            sendPacket(wrapper); +        } catch (Exception e) { +            this.platform.getPlugin().log(Level.WARNING, "Exception occurred while sending data to the socket"); +            e.printStackTrace(); +        } +    } + +    /** +     * Sends a packet to the socket. +     * +     * @param packet the packet to send +     */ +    private void sendPacket(PacketWrapper packet) throws Exception { +        ByteString msg = packet.toByteString(); + +        // sign the message using the server private key +        Signature sign = CRYPTO.createSignature(); +        sign.initSign(this.privateKey); +        sign.update(msg.asReadOnlyByteBuffer()); +        byte[] signature = sign.sign(); + +        sendRawPacket(RawPacket.newBuilder() +                .setVersion(VERSION_1) +                .setSignature(ByteString.copyFrom(signature)) +                .setMessage(msg) +                .build() +        ); +    } + +    /** +     * Sends a raw packet to the socket. +     * +     * @param packet the packet to send +     */ +    private void sendRawPacket(RawPacket packet) throws IOException { +        byte[] buf = packet.toByteArray(); +        String encoded = Base64.getEncoder().encodeToString(buf); +        this.socket.send(encoded); +    } + +    /** +     * Decodes a raw packet sent to the socket. +     * +     * @param data the encoded data +     * @return the decoded packet +     */ +    private RawPacket decodeRawPacket(CharSequence data) throws IOException { +        byte[] buf = Base64.getDecoder().decode(data.toString()); +        return RawPacket.parseFrom(buf); +    } + +    /** +     * Handles a raw packet sent to the socket +     * +     * @param packet the packet +     */ +    private void handleRawPacket(RawPacket packet) throws Exception { +        int version = packet.getVersion(); +        if (version != VERSION_1) { +            throw new IllegalArgumentException("Unsupported packet version " + version); +        } + +        ByteString message = packet.getMessage(); +        PublicKey publicKey = CRYPTO.decodePublicKey(packet.getPublicKey()); +        ByteString signature = packet.getSignature(); + +        boolean verified = false; +        if (signature != null && publicKey != null && this.listener.isKeyTrusted(publicKey)) { +            Signature sign = CRYPTO.createSignature(); +            sign.initVerify(publicKey); +            sign.update(message.asReadOnlyByteBuffer()); + +            verified = sign.verify(signature.toByteArray()); +        } + +        PacketWrapper wrapper = PacketWrapper.parseFrom(message); +        this.listener.onPacket(wrapper, verified, publicKey); +    } + +    @Override +    public void close() { +        this.socket.close(1001 /* going away */, "spark plugin disconnected"); +    } +} | 
