Class GraphWiseModel<Config extends GraphWiseModelConfig, Metadata extends oracle.pgx.api.internal.mllib.GraphWiseModelMetadata<Config>, ModelType extends GraphWiseModel<Config, Metadata, ModelType>>

  • All Implemented Interfaces:
    java.lang.AutoCloseable
    Direct Known Subclasses:
    SupervisedGraphWiseModel, UnsupervisedGraphWiseModel

    public abstract class GraphWiseModel<Config extends GraphWiseModelConfig, Metadata extends oracle.pgx.api.internal.mllib.GraphWiseModelMetadata<Config>, ModelType extends GraphWiseModel<Config, Metadata, ModelType>>
    extends Model<ModelType>
    Base class for GraphWise models (supervised and unsupervised).
    Since:
    19.4
    • Constructor Detail

      • GraphWiseModel

        public GraphWiseModel(PgxSession session,
                              oracle.pgx.api.internal.Core core,
                              java.util.function.Supplier<java.lang.String> keystorePathSupplier,
                              java.util.function.Supplier<char[]> keystorePasswordSupplier,
                              Metadata modelMetadata,
                              java.util.function.BiFunction<PgxSession, oracle.pgx.api.internal.Graph, PgxGraph> graphConstructor)
        Do not call this constructor to obtain a model. Use SupervisedGraphWiseModelBuilder (or UnsupervisedGraphWiseModelBuilder for unsupervised models) instead.
        Parameters:
        session - PgxSession to which the model is connected
        core - Core to which the model is connected
        modelMetadata - Metadata concerning the different hyper-parameters of the GraphWise Model
        Since:
        19.4
    • Method Detail

      • destroy

        public void destroy()
                     throws java.util.concurrent.ExecutionException,
                            java.lang.InterruptedException
        Blocking version of destroyAsync(). Calls destroyAsync() and waits for the returned PgxFuture to complete.
        Throws:
        java.lang.InterruptedException - if the caller thread gets interrupted while waiting for completion.
        java.util.concurrent.ExecutionException - if any exception occurred during asynchronous execution. The actual exception will be nested.
        Since:
        19.4
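The blocking methods on this page (destroy(), getTrainingLog(), fit(...)) share one convention: call the *Async variant and wait for the returned future to complete. The sketch below illustrates that convention using only the JDK. It assumes java.util.concurrent.CompletableFuture as a stand-in for PgxFuture; ToyModel, its fields, and the choice to have close() call destroy() are hypothetical and do not reproduce PGX internals.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;

public class BlockingOverAsyncSketch {

    // Hypothetical stand-in for a model; NOT the PGX GraphWiseModel class.
    static class ToyModel implements AutoCloseable {
        private volatile boolean destroyed = false;

        // Asynchronous variant: schedules the cleanup and returns a future immediately.
        CompletableFuture<Void> destroyAsync() {
            return CompletableFuture.runAsync(() -> destroyed = true);
        }

        // Blocking variant: calls destroyAsync() and waits for the future to complete.
        // Any failure in the async task surfaces as an ExecutionException.
        void destroy() throws ExecutionException, InterruptedException {
            destroyAsync().get();
        }

        boolean isDestroyed() {
            return destroyed;
        }

        // Assumption for this sketch only: close() releases the model via destroy(),
        // so try-with-resources cleans it up automatically.
        @Override
        public void close() throws ExecutionException, InterruptedException {
            destroy();
        }
    }

    public static void main(String[] args) throws Exception {
        ToyModel model = new ToyModel();
        try (ToyModel m = model) {
            // use the model here
        }
        System.out.println("destroyed: " + model.isDestroyed());
    }
}
```

Blocking on get() keeps calling code simple; callers that must not block can work with the future returned by the *Async method directly.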
      • getNumEpochs

        public int getNumEpochs()
        Gets the number of epochs to train the model
        Returns:
        number of epochs to train the model
        Since:
        19.4
      • getLearningRate

        public double getLearningRate()
        Gets the initial learning rate
        Returns:
        initial learning rate
        Since:
        19.4
      • getBatchSize

        public int getBatchSize()
        Gets the batch size
        Returns:
        batch size
        Since:
        19.4
      • getEmbeddingDim

        public int getEmbeddingDim()
        Gets the dimension of the embeddings
        Returns:
        embedding dimension
        Since:
        19.4
      • getSeed

        public java.lang.Integer getSeed()
        Gets the random seed
        Returns:
        random seed
        Since:
        19.4
      • getConvLayerConfigs

        public GraphWiseBaseConvLayerConfig[] getConvLayerConfigs()
        Gets the configuration objects for the convolutional layers
        Returns:
        configurations
        Since:
        19.4
      • getVertexInputPropertyNames

        public java.util.List<java.lang.String> getVertexInputPropertyNames()
        Gets the names of the vertex properties used as input features
        Returns:
        names of the vertex input properties
        Since:
        19.4
      • getEdgeInputPropertyNames

        public java.util.List<java.lang.String> getEdgeInputPropertyNames()
        Gets the names of the edge properties used as input features
        Returns:
        names of the edge input properties
        Since:
        21.2
      • isFitted

        public boolean isFitted()
        Checks if the model is fitted
        Returns:
        true if the model is fitted
        Since:
        19.4
      • getTrainingLoss

        public double getTrainingLoss()
        Gets the final training loss
        Returns:
        training loss
        Since:
        19.4
      • getInputFeatureDim

        public int getInputFeatureDim()
        Gets the input feature dimension, that is, the dimension of all the input vertex properties when concatenated
        Returns:
        input feature dimension
        Since:
        19.4
      • getEdgeInputFeatureDim

        public int getEdgeInputFeatureDim()
        Gets the edge input feature dimension, that is, the dimension of all the input edge properties when concatenated
        Returns:
        edge input feature dimension
        Since:
        21.2
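getInputFeatureDim() and getEdgeInputFeatureDim() report the length of the vector formed by concatenating the input properties. The sketch below shows that arithmetic for a hypothetical property schema. It assumes every input property is numeric, with a scalar contributing 1 and a vector property of length k contributing k; encodings of categorical properties are not modeled, and the property names and dimensions are invented for illustration.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class InputFeatureDimSketch {

    // Sums the per-property dimensions, i.e. the length of the vector obtained by
    // concatenating the listed properties in order. Assumes numeric properties only:
    // a scalar counts as 1, a vector property of length k counts as k.
    static int concatenatedDim(List<String> propertyNames, Map<String, Integer> propertyDims) {
        int total = 0;
        for (String name : propertyNames) {
            Integer dim = propertyDims.get(name);
            if (dim == null) {
                throw new IllegalArgumentException("unknown property: " + name);
            }
            total += dim;
        }
        return total;
    }

    public static void main(String[] args) {
        // Hypothetical property schema, for illustration only.
        Map<String, Integer> dims = new LinkedHashMap<>();
        dims.put("age", 1);          // scalar vertex property
        dims.put("features", 16);    // vector vertex property of length 16
        dims.put("weight", 1);       // scalar edge property
        dims.put("edgeFeatures", 4); // vector edge property of length 4

        int vertexDim = concatenatedDim(List.of("age", "features"), dims);
        int edgeDim = concatenatedDim(List.of("weight", "edgeFeatures"), dims);
        System.out.println(vertexDim + " " + edgeDim); // prints "17 5"
    }
}
```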
      • getConfig

        public Config getConfig()
        Gets the model configuration object
        Returns:
        model configuration
        Since:
        19.4
      • getTrainingLogAsync

        public abstract PgxFuture<PgxFrame> getTrainingLogAsync()
        Gets the training log, which contains the evaluation results from validation. It is available only after the model has been trained with a validation graph.
        Returns:
        training log
        Since:
        24.2
      • getTrainingLog

        public PgxFrame getTrainingLog()
                                throws java.util.concurrent.ExecutionException,
                                       java.lang.InterruptedException
        Blocking version of getTrainingLogAsync(). Calls getTrainingLogAsync() and waits for the returned PgxFuture to complete.
        Returns:
        training log
        Throws:
        java.lang.InterruptedException - if the caller thread gets interrupted while waiting for completion.
        java.util.concurrent.ExecutionException - if any exception occurred during asynchronous execution. The actual exception will be nested.
        Since:
        24.2
      • fitAsync

        public abstract PgxFuture<java.lang.Double> fitAsync(PgxGraph graph)
        Trains the GraphWise model on the input graph.
        Parameters:
        graph - input graph to fit on.
        Returns:
        a PgxFuture that completes with the training loss of the last batch
        Since:
        19.4
      • fitAsync

        public abstract PgxFuture<java.lang.Double> fitAsync(PgxGraph trainGraph,
                                                             PgxGraph valGraph)
        Trains the GraphWise model on the input trainGraph and evaluates it on the input valGraph.
        Parameters:
        trainGraph - input graph to fit on.
        valGraph - input graph to evaluate on for validation.
        Returns:
        a PgxFuture that completes with the training loss of the last batch
        Since:
        24.2
      • fit

        public double fit(PgxGraph graph)
                   throws java.util.concurrent.ExecutionException,
                          java.lang.InterruptedException
        Blocking version of fitAsync(PgxGraph). Calls fitAsync(PgxGraph) and waits for the returned PgxFuture to complete.
        Parameters:
        graph - input graph to fit on.
        Returns:
        the training loss of the last batch
        Throws:
        java.lang.InterruptedException - if the caller thread gets interrupted while waiting for completion.
        java.util.concurrent.ExecutionException - if any exception occurred during asynchronous execution. The actual exception will be nested.
        Since:
        19.4
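Because the blocking fit(...) waits on an asynchronous task, a failure during training reaches the caller as an ExecutionException with the actual exception nested as its cause. The JDK-only sketch below shows how a caller can unwrap it. The fit(boolean) method, its failure message, and the returned loss value are hypothetical placeholders, with CompletableFuture standing in for PgxFuture.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;

public class UnwrapExecutionExceptionSketch {

    // Hypothetical blocking fit: waits on an async task, so a failure inside the task
    // is rethrown by get() as an ExecutionException wrapping the original exception.
    static double fit(boolean fail) throws ExecutionException, InterruptedException {
        CompletableFuture<Double> future = CompletableFuture.supplyAsync(() -> {
            if (fail) {
                throw new IllegalStateException("training failed");
            }
            return 0.42; // placeholder loss of the last batch, not a real result
        });
        return future.get();
    }

    // Returns the nested cause, i.e. the exception actually thrown during async execution.
    static Throwable rootFailure(ExecutionException e) {
        return e.getCause() != null ? e.getCause() : e;
    }

    public static void main(String[] args) throws InterruptedException {
        try {
            fit(true);
        } catch (ExecutionException e) {
            Throwable cause = rootFailure(e);
            // prints "IllegalStateException: training failed"
            System.out.println(cause.getClass().getSimpleName() + ": " + cause.getMessage());
        }
    }
}
```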
      • fit

        public double fit(PgxGraph trainGraph,
                          PgxGraph valGraph)
                   throws java.util.concurrent.ExecutionException,
                          java.lang.InterruptedException
        Blocking version of fitAsync(PgxGraph, PgxGraph). Calls fitAsync(PgxGraph, PgxGraph) and waits for the returned PgxFuture to complete.
        Parameters:
        trainGraph - input graph to fit on.
        valGraph - input graph to evaluate on for validation.
        Returns:
        the training loss of the last batch
        Throws:
        java.lang.InterruptedException - if the caller thread gets interrupted while waiting for completion.
        java.util.concurrent.ExecutionException - if any exception occurred during asynchronous execution. The actual exception will be nested.
        Since:
        24.2
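Instead of calling the blocking fit(trainGraph, valGraph), a caller can chain follow-up work onto the future returned by fitAsync(trainGraph, valGraph). The sketch below illustrates that non-blocking style with CompletableFuture standing in for PgxFuture. ToyModel, its String graph arguments, and the fixed 0.25 loss are hypothetical placeholders; no real training or validation happens.

```java
import java.util.concurrent.CompletableFuture;

public class ChainedFitSketch {

    // Hypothetical stand-in for a trainable model; NOT the PGX API.
    static class ToyModel {
        private volatile boolean fitted = false;
        private volatile double trainingLoss = Double.NaN;

        // Pretends to train on trainGraph and validate on valGraph, completing with
        // the loss of the last batch (a fixed placeholder value here).
        CompletableFuture<Double> fitAsync(String trainGraph, String valGraph) {
            return CompletableFuture.supplyAsync(() -> {
                trainingLoss = 0.25; // placeholder, not a real training result
                fitted = true;
                return trainingLoss;
            });
        }

        boolean isFitted() { return fitted; }

        double getTrainingLoss() { return trainingLoss; }
    }

    public static void main(String[] args) {
        ToyModel model = new ToyModel();
        // Chain follow-up work instead of blocking immediately; the calling thread
        // only waits when join() is reached.
        CompletableFuture<String> report = model
                .fitAsync("trainGraph", "valGraph")
                .thenApply(loss -> "fitted=" + model.isFitted() + ", loss=" + loss);
        System.out.println(report.join()); // prints "fitted=true, loss=0.25"
    }
}
```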
      • inferEmbeddingsAsync

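        public abstract <ID> PgxFuture<PgxFrame> inferEmbeddingsAsync(PgxGraph graph,
                                                                      java.lang.Iterable<PgxVertex<ID>> vertices)
        Infers the embeddings for the specified vertices.
        Parameters:
        graph - input graph containing the vertices
        vertices - vertices for which to infer embeddings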
        Returns:
        PgxFrame containing the embeddings for each vertex.
        Since:
        19.4