diff --git a/engines/python/src/main/java/ai/djl/python/engine/DsEngineProvider.java b/engines/python/src/main/java/ai/djl/python/engine/MpiEngineProvider.java similarity index 64% rename from engines/python/src/main/java/ai/djl/python/engine/DsEngineProvider.java rename to engines/python/src/main/java/ai/djl/python/engine/MpiEngineProvider.java index b098a83c3..27550fb99 100644 --- a/engines/python/src/main/java/ai/djl/python/engine/DsEngineProvider.java +++ b/engines/python/src/main/java/ai/djl/python/engine/MpiEngineProvider.java @@ -15,17 +15,17 @@ import ai.djl.engine.EngineProvider; /** {@code DsEngineProvider} is the DeepSpeed implementation of {@link EngineProvider}. */ -public class DsEngineProvider extends PyEngineProvider { +public class MpiEngineProvider extends PyEngineProvider { - /** Constructs a new {@code DsEngineProvider} instance. */ - public DsEngineProvider() { + /** Constructs a new {@code MpiEngineProvider} instance. */ + public MpiEngineProvider() { mpiMode = true; } /** {@inheritDoc} */ @Override public String getEngineName() { - return "DeepSpeed"; + return "MPI"; } /** {@inheritDoc} */ @@ -34,8 +34,8 @@ public int getEngineRank() { return PyEngine.RANK + 1; } - /** {@code FtEngineProvider} is the alias of {@link DsEngineProvider}. */ - public static final class FtEngineProvider extends DsEngineProvider { + /** {@code FtEngineProvider} is the alias of {@link MpiEngineProvider}. */ + public static final class FtEngineProvider extends MpiEngineProvider { /** {@inheritDoc} */ @Override @@ -43,4 +43,14 @@ public String getEngineName() { return "FasterTransformer"; } } + + /** {@code DsEngineProvider} is the alias of {@link MpiEngineProvider}. */ + public static final class DsEngineProvider extends MpiEngineProvider { + + /** {@inheritDoc} */ + @Override + public String getEngineName() { + return "DeepSpeed"; + } + } } diff --git a/engines/python/src/main/resources/META-INF/services/ai.djl.engine.EngineProvider b/engines/python/src/main/resources/META-INF/services/ai.djl.engine.EngineProvider index ce1c3c0d3..d4c78eefa 100644 --- a/engines/python/src/main/resources/META-INF/services/ai.djl.engine.EngineProvider +++ b/engines/python/src/main/resources/META-INF/services/ai.djl.engine.EngineProvider @@ -1,3 +1,4 @@ ai.djl.python.engine.PyEngineProvider -ai.djl.python.engine.DsEngineProvider -ai.djl.python.engine.DsEngineProvider$FtEngineProvider +ai.djl.python.engine.MpiEngineProvider +ai.djl.python.engine.MpiEngineProvider$FtEngineProvider +ai.djl.python.engine.MpiEngineProvider$DsEngineProvider