Package oracle.pgx.api.mllib
Class GraphWiseModel<Config extends GraphWiseModelConfig,Metadata extends oracle.pgx.api.internal.mllib.GraphWiseModelMetadata<Config>,ModelType extends GraphWiseModel<Config,Metadata,ModelType>>
- java.lang.Object
  - oracle.pgx.api.mllib.Model<ModelType>
    - oracle.pgx.api.mllib.GraphWiseModel<Config,Metadata,ModelType>
- All Implemented Interfaces:
- java.lang.AutoCloseable
- Direct Known Subclasses:
- SupervisedGraphWiseModel, UnsupervisedGraphWiseModel
public abstract class GraphWiseModel<Config extends GraphWiseModelConfig,Metadata extends oracle.pgx.api.internal.mllib.GraphWiseModelMetadata<Config>,ModelType extends GraphWiseModel<Config,Metadata,ModelType>> extends Model<ModelType>
Base class for GraphWise models.
- Since:
- 19.4
-
-
Constructor Summary
Constructor | Description
GraphWiseModel(PgxSession session, oracle.pgx.api.internal.Core core, java.util.function.Supplier<java.lang.String> keystorePathSupplier, java.util.function.Supplier<char[]> keystorePasswordSupplier, Metadata modelMetadata, java.util.function.BiFunction<PgxSession,oracle.pgx.api.internal.Graph,PgxGraph> graphConstructor) | This constructor should never be used to get a model.
-
Method Summary
Modifier and Type | Method | Description
void | destroy() | Blocking version of destroyAsync().
PgxFuture<java.lang.Void> | destroyAsync() | Destroys a GraphWise model.
double | fit(PgxGraph graph) | Blocking version of fitAsync(PgxGraph).
double | fit(PgxGraph trainGraph, PgxGraph valGraph) | Blocking version of fitAsync(PgxGraph, PgxGraph).
abstract PgxFuture<java.lang.Double> | fitAsync(PgxGraph graph) | Trains the GraphWise model on the input graph.
abstract PgxFuture<java.lang.Double> | fitAsync(PgxGraph trainGraph, PgxGraph valGraph) | Trains the GraphWise model on the input trainGraph and evaluates it on the input valGraph.
int | getBatchSize() | Gets the batch size.
Config | getConfig() | Gets the model configuration object.
GraphWiseBaseConvLayerConfig[] | getConvLayerConfigs() | Gets the configuration objects for the convolutional layers.
int | getEdgeInputFeatureDim() | Gets the edge input feature dimension, that is, the dimension of all the input edge properties when concatenated.
java.util.List<java.lang.String> | getEdgeInputPropertyNames() | Gets the edge input feature names.
int | getEmbeddingDim() | Gets the dimension of the embeddings.
int | getInputFeatureDim() | Gets the input feature dimension, that is, the dimension of all the input vertex properties when concatenated.
double | getLearningRate() | Gets the initial learning rate.
int | getNumEpochs() | Gets the number of epochs to train the model.
java.lang.Integer | getSeed() | Gets the random seed.
PgxFrame | getTrainingLog() | Blocking version of getTrainingLogAsync().
abstract PgxFuture<PgxFrame> | getTrainingLogAsync() | Gets the training log that has evaluation results from validation.
double | getTrainingLoss() | Gets the final training loss.
java.util.List<java.lang.String> | getVertexInputPropertyNames() | Gets the vertex input feature names.
<ID> PgxFrame | inferEmbeddings(PgxGraph graph, java.lang.Iterable<PgxVertex<ID>> vertices) | Blocking version of inferEmbeddingsAsync(PgxGraph, Iterable).
abstract <ID> PgxFuture<PgxFrame> | inferEmbeddingsAsync(PgxGraph graph, java.lang.Iterable<PgxVertex<ID>> vertices) | Infers the embeddings for the specified vertices.
boolean | isFitted() | Checks if the model is fitted.
-
-
-
Constructor Detail
-
GraphWiseModel
public GraphWiseModel(PgxSession session, oracle.pgx.api.internal.Core core, java.util.function.Supplier<java.lang.String> keystorePathSupplier, java.util.function.Supplier<char[]> keystorePasswordSupplier, Metadata modelMetadata, java.util.function.BiFunction<PgxSession,oracle.pgx.api.internal.Graph,PgxGraph> graphConstructor)
This constructor should never be used to get a model. Use SupervisedGraphWiseModelBuilder instead.
- Parameters:
- session - PgxSession to which the model is connected
- core - Core to which the model is connected
- modelMetadata - Metadata concerning the different hyper-parameters of the GraphWise model
- Since:
- 19.4
-
-
Method Detail
-
destroyAsync
public PgxFuture<java.lang.Void> destroyAsync()
Destroys a GraphWise model.
- Specified by:
- destroyAsync in class Model<ModelType extends GraphWiseModel<Config,Metadata,ModelType>>
- Returns:
- a future which will be completed once the destruction request finishes.
- Since:
- 19.4
-
destroy
public void destroy() throws java.util.concurrent.ExecutionException, java.lang.InterruptedException
Blocking version of destroyAsync(). Calls destroyAsync() and waits for the returned PgxFuture to complete.
- Throws:
- java.lang.InterruptedException - if the caller thread gets interrupted while waiting for completion.
- java.util.concurrent.ExecutionException - if any exception occurred during asynchronous execution. The actual exception will be nested.
- Since:
- 19.4
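The blocking methods on this page all follow the same pattern: call the Async variant, wait for the returned future, and surface failures as an ExecutionException with the actual exception nested inside. The sketch below illustrates that pattern only. It does not use PGX: java.util.concurrent.CompletableFuture stands in for PgxFuture, and the destroyAsync, destroy, and describeOutcome methods are hypothetical stand-ins, not the library's implementation.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;

final class BlockingWrapperSketch {

    // Hypothetical stand-in for an asynchronous call such as destroyAsync().
    // CompletableFuture is used in place of PgxFuture (an assumption).
    static CompletableFuture<Void> destroyAsync(boolean fail) {
        return CompletableFuture.runAsync(() -> {
            if (fail) {
                throw new IllegalStateException("model already destroyed");
            }
        });
    }

    // Blocking variant: calls the async method and waits for the future to complete.
    static void destroy(boolean fail) throws ExecutionException, InterruptedException {
        destroyAsync(fail).get();
    }

    // Returns "ok" on success, otherwise a description of the nested exception.
    static String describeOutcome(boolean fail) {
        try {
            destroy(fail);
            return "ok";
        } catch (ExecutionException e) {
            // The actual exception is nested as the cause.
            Throwable cause = e.getCause();
            return cause.getClass().getSimpleName() + ": " + cause.getMessage();
        } catch (InterruptedException e) {
            // Restore the interrupt flag so callers can still observe it.
            Thread.currentThread().interrupt();
            return "interrupted";
        }
    }

    public static void main(String[] args) {
        System.out.println(describeOutcome(false));
        System.out.println(describeOutcome(true));
    }
}
```

Callers that need the underlying failure should inspect getCause() rather than the ExecutionException itself.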
-
getNumEpochs
public int getNumEpochs()
Gets the number of epochs to train the model.
- Returns:
- number of epochs to train the model
- Since:
- 19.4
-
getLearningRate
public double getLearningRate()
Gets the initial learning rate.
- Returns:
- initial learning rate
- Since:
- 19.4
-
getBatchSize
public int getBatchSize()
Gets the batch size.
- Returns:
- batch size
- Since:
- 19.4
-
getEmbeddingDim
public int getEmbeddingDim()
Gets the dimension of the embeddings.
- Returns:
- embedding dimension
- Since:
- 19.4
-
getSeed
public java.lang.Integer getSeed()
Gets the random seed.
- Returns:
- random seed
- Since:
- 19.4
-
getConvLayerConfigs
public GraphWiseBaseConvLayerConfig[] getConvLayerConfigs()
Gets the configuration objects for the convolutional layers.
- Returns:
- configurations
- Since:
- 19.4
-
getVertexInputPropertyNames
public java.util.List<java.lang.String> getVertexInputPropertyNames()
Gets the vertex input feature names.
- Returns:
- vertex input feature names
- Since:
- 19.4
-
getEdgeInputPropertyNames
public java.util.List<java.lang.String> getEdgeInputPropertyNames()
Gets the edge input feature names.
- Returns:
- edge input feature names
- Since:
- 21.2
-
isFitted
public boolean isFitted()
Checks if the model is fitted.
- Returns:
- true if the model is fitted
- Since:
- 19.4
-
getTrainingLoss
public double getTrainingLoss()
Gets the final training loss.
- Returns:
- training loss
- Since:
- 19.4
-
getInputFeatureDim
public int getInputFeatureDim()
Gets the input feature dimension, that is, the dimension of all the input vertex properties when concatenated.
- Returns:
- input feature dimension
- Since:
- 19.4
-
getEdgeInputFeatureDim
public int getEdgeInputFeatureDim()
Gets the edge input feature dimension, that is, the dimension of all the input edge properties when concatenated.
- Returns:
- edge input feature dimension
- Since:
- 21.2
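Both feature-dimension getters describe the same arithmetic: concatenating the input properties yields a vector whose length is the sum of the individual property widths. A minimal sketch of that sum follows. The property names and widths are hypothetical and chosen only for illustration, and the code does not call PGX.

```java
import java.util.LinkedHashMap;
import java.util.Map;

final class FeatureDimSketch {

    // Sums the widths of the given properties, which equals the length
    // of the vector obtained by concatenating them.
    static int concatenatedDim(Map<String, Integer> propertyDims) {
        int total = 0;
        for (int dim : propertyDims.values()) {
            total += dim;
        }
        return total;
    }

    // Hypothetical vertex properties: one 16-wide vector, one scalar, one 3-wide vector.
    static int exampleVertexDim() {
        Map<String, Integer> vertexProps = new LinkedHashMap<>();
        vertexProps.put("features", 16);
        vertexProps.put("age", 1);
        vertexProps.put("position", 3);
        return concatenatedDim(vertexProps);
    }

    public static void main(String[] args) {
        System.out.println(exampleVertexDim()); // prints 20
    }
}
```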
-
getConfig
public Config getConfig()
Gets the model configuration object.
- Returns:
- model configuration
- Since:
- 19.4
-
getTrainingLogAsync
public abstract PgxFuture<PgxFrame> getTrainingLogAsync()
Gets the training log, which contains the evaluation results from validation. It is available only after the model has been trained with validation.
- Returns:
- training log
- Since:
- 24.2
-
getTrainingLog
public PgxFrame getTrainingLog() throws java.util.concurrent.ExecutionException, java.lang.InterruptedException
Blocking version of getTrainingLogAsync(). Calls getTrainingLogAsync() and waits for the returned PgxFuture to complete.
- Returns:
- training log
- Throws:
- java.lang.InterruptedException - if the caller thread gets interrupted while waiting for completion.
- java.util.concurrent.ExecutionException - if any exception occurred during asynchronous execution. The actual exception will be nested.
- Since:
- 24.2
-
fitAsync
public abstract PgxFuture<java.lang.Double> fitAsync(PgxGraph graph)
Trains the GraphWise model on the input graph.
- Parameters:
- graph - input graph to fit on.
- Since:
- 19.4
-
fitAsync
public abstract PgxFuture<java.lang.Double> fitAsync(PgxGraph trainGraph, PgxGraph valGraph)
Trains the GraphWise model on the input trainGraph and evaluates it on the input valGraph.
- Parameters:
- trainGraph - input graph to fit on.
- valGraph - input graph to evaluate on for validation.
- Since:
- 24.2
-
fit
public double fit(PgxGraph graph) throws java.util.concurrent.ExecutionException, java.lang.InterruptedException
Blocking version of fitAsync(PgxGraph). Calls fitAsync(PgxGraph) and waits for the returned PgxFuture to complete.
- Parameters:
- graph - input graph to fit on.
- Returns:
- the training loss of the last batch
- Throws:
- java.lang.InterruptedException - if the caller thread gets interrupted while waiting for completion.
- java.util.concurrent.ExecutionException - if any exception occurred during asynchronous execution. The actual exception will be nested.
- Since:
- 19.4
-
fit
public double fit(PgxGraph trainGraph, PgxGraph valGraph) throws java.util.concurrent.ExecutionException, java.lang.InterruptedException
Blocking version of fitAsync(PgxGraph, PgxGraph). Calls fitAsync(PgxGraph, PgxGraph) and waits for the returned PgxFuture to complete.
- Parameters:
- trainGraph - input graph to fit on.
- valGraph - input graph to evaluate on for validation.
- Returns:
- the training loss of the last batch
- Throws:
- java.lang.InterruptedException - if the caller thread gets interrupted while waiting for completion.
- java.util.concurrent.ExecutionException - if any exception occurred during asynchronous execution. The actual exception will be nested.
- Since:
- 24.2
-
inferEmbeddingsAsync
public abstract <ID> PgxFuture<PgxFrame> inferEmbeddingsAsync(PgxGraph graph, java.lang.Iterable<PgxVertex<ID>> vertices)
Infers the embeddings for the specified vertices.
- Returns:
- PgxFrame containing the embeddings for each vertex.
- Since:
- 19.4
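Because fitAsync and inferEmbeddingsAsync both return futures, training and inference can in principle be chained without blocking the calling thread in between. The sketch below shows one way to express that, assuming PgxFuture supports CompletableFuture-style composition. It does not use PGX: CompletableFuture stands in for PgxFuture, plain strings stand in for graphs, vertices, and frame rows, and every method shown is a hypothetical stand-in.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;

final class AsyncChainSketch {

    // Hypothetical stand-in for fitAsync(graph): completes with a training loss.
    static CompletableFuture<Double> fitAsync(String graph) {
        return CompletableFuture.supplyAsync(() -> 0.25);
    }

    // Hypothetical stand-in for inferEmbeddingsAsync(graph, vertices):
    // completes with one placeholder row per requested vertex.
    static CompletableFuture<List<String>> inferEmbeddingsAsync(String graph, List<String> vertices) {
        return CompletableFuture.supplyAsync(() -> {
            List<String> rows = new ArrayList<>();
            for (String vertex : vertices) {
                rows.add(vertex + " -> embedding");
            }
            return rows;
        });
    }

    // Starts inference only once training has completed, without blocking in between.
    static CompletableFuture<List<String>> fitThenInfer(String graph, List<String> vertices) {
        return fitAsync(graph).thenCompose(loss -> inferEmbeddingsAsync(graph, vertices));
    }

    // Runs the chain on three hypothetical vertices and counts the resulting rows.
    static int exampleRowCount() {
        return fitThenInfer("g", List.of("v1", "v2", "v3")).join().size();
    }

    public static void main(String[] args) {
        System.out.println(exampleRowCount());
    }
}
```

The caller blocks only at the final join(); the blocking fit and inferEmbeddings methods remain the simpler choice when no other work needs to overlap with training.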
-
inferEmbeddings
public <ID> PgxFrame inferEmbeddings(PgxGraph graph, java.lang.Iterable<PgxVertex<ID>> vertices)
Blocking version of inferEmbeddingsAsync(PgxGraph, Iterable). Infers the embeddings for the specified vertices.
- Returns:
- PgxFrame containing the embeddings for each vertex.
- Since:
- 19.4
-
-