diff --git a/jupyter-platform/platform-api/pom.xml b/jupyter-platform/platform-api/pom.xml index 578a4a4..894dfe7 100644 --- a/jupyter-platform/platform-api/pom.xml +++ b/jupyter-platform/platform-api/pom.xml @@ -149,6 +149,11 @@ guava 29.0-jre + + com.hierynomus + sshj + 0.32.0 + diff --git a/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/Config.java b/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/Config.java index 0b9bc23..3ce7257 100644 --- a/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/Config.java +++ b/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/Config.java @@ -76,9 +76,18 @@ public void addCorsMappings(CorsRegistry registry) { public void addInterceptors(InterceptorRegistry registry) { if (superUserMode) { - registry.addInterceptor(authenticator).excludePathPatterns("/api/archive/**").excludePathPatterns("/api/admin/**"); + registry.addInterceptor(authenticator) + .excludePathPatterns("/api/archive/**") + .excludePathPatterns("/api/remote/run/**") + .excludePathPatterns("/api/admin/**") + .excludePathPatterns("/api/job/**") + .excludePathPatterns("/error"); } else { - registry.addInterceptor(authenticator).excludePathPatterns("/api/archive/**"); + registry.addInterceptor(authenticator) + .excludePathPatterns("/api/archive/**") + .excludePathPatterns("/api/remote/run/**") + .excludePathPatterns("/api/job/**") + .excludePathPatterns("/error"); } } } diff --git a/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/controller/ArchiveController.java b/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/controller/ArchiveController.java index 8d636ee..3e9db79 100644 --- a/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/controller/ArchiveController.java +++ b/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/controller/ArchiveController.java @@ -19,17 +19,30 @@ import org.apache.airavata.jupyter.api.entity.ArchiveEntity; import org.apache.airavata.jupyter.api.entity.NotebookEntity; +import org.apache.airavata.jupyter.api.entity.job.JobEntity; import org.apache.airavata.jupyter.api.repo.ArchiveRepository; +import org.apache.airavata.jupyter.api.repo.JobRepository; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; import org.springframework.web.bind.annotation.*; import org.springframework.web.multipart.MultipartFile; +import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBody; +import javax.servlet.http.HttpServletResponse; +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import java.util.*; +import java.util.zip.ZipEntry; +import java.util.zip.ZipOutputStream; @RestController @RequestMapping(path = "/api/archive") @@ -41,6 +54,9 @@ public class ArchiveController { @Autowired private ArchiveRepository archiveRepository; + @Autowired + private JobRepository jobRepository; + @PostMapping(path = "/", consumes = "application/json", produces = "application/json") public ArchiveEntity createArchive(@RequestBody ArchiveEntity archiveEntity) { ArchiveEntity saved = archiveRepository.save(archiveEntity); @@ -70,4 +86,43 @@ public Map singleFileUpload(@RequestParam("file") MultipartFile return Collections.singletonMap("path", path.toAbsolutePath().toString()); } + + @GetMapping (value = "/download/{jobId}", produces = MediaType.APPLICATION_JSON_VALUE) + public ResponseEntity download(@PathVariable String jobId, final HttpServletResponse response) + throws Exception { + response.setContentType("application/zip"); + response.setHeader( + "Content-Disposition", + "attachment;filename=REMOTE_STATE.zip"); + + StreamingResponseBody stream = out -> { + + + Optional jobOp = jobRepository.findById(jobId); + JobEntity job = jobOp.get(); + final File directory = new File(job.getLocalWorkingPath()); + final ZipOutputStream zipOut = new ZipOutputStream(response.getOutputStream()); + + if(directory.exists() && directory.isDirectory()) { + try { + for (final File file : directory.listFiles()) { + final InputStream inputStream=new FileInputStream(file); + final ZipEntry zipEntry = new ZipEntry(file.getName()); + zipOut.putNextEntry(zipEntry); + byte[] bytes=new byte[1024]; + int length; + while ((length=inputStream.read(bytes)) >= 0) { + zipOut.write(bytes, 0, length); + } + inputStream.close(); + } + zipOut.close(); + } catch (final IOException e) { + logger.error("Exception while reading and streaming data {} ", e); + } + } + }; + logger.info("steaming response {} ", stream); + return new ResponseEntity(stream, HttpStatus.OK); + } } diff --git a/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/controller/JobController.java b/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/controller/JobController.java new file mode 100644 index 0000000..c5ab55e --- /dev/null +++ b/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/controller/JobController.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.airavata.jupyter.api.controller; + +import org.apache.airavata.jupyter.api.entity.job.JobStatusEntity; +import org.apache.airavata.jupyter.api.repo.JobRepository; +import org.apache.airavata.jupyter.api.repo.JobStatusRepository; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PathVariable; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; +import java.util.Optional; + +@RestController +@RequestMapping(path = "/api/job") +public class JobController { + + @Autowired + private JobStatusRepository jobStatusRepository; + + @Autowired + private JobRepository jobRepository; + + @GetMapping(path = "/status/{jobId}") + public JobStatusEntity getJobStatus(@PathVariable String jobId) throws Exception { + Optional jobSt = jobStatusRepository.findFirstByJobIdOrderByUpdatedTimeAsc(jobId); + return jobSt.orElseThrow(() -> new Exception("Could not find job status for job id " + jobId)); + } +} diff --git a/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/controller/RemoteExecController.java b/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/controller/RemoteExecController.java new file mode 100644 index 0000000..10be1f5 --- /dev/null +++ b/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/controller/RemoteExecController.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.airavata.jupyter.api.controller; + +import org.apache.airavata.jupyter.api.entity.ArchiveEntity; +import org.apache.airavata.jupyter.api.entity.interfacing.LocalInterfaceEntity; +import org.apache.airavata.jupyter.api.entity.interfacing.SSHInterfaceEntity; +import org.apache.airavata.jupyter.api.entity.remote.ComputeEntity; +import org.apache.airavata.jupyter.api.repo.*; +import org.apache.airavata.jupyter.api.util.remote.interfacing.InterfacingProtocol; +import org.apache.airavata.jupyter.api.util.remote.interfacing.LocalInterfacingProtocol; +import org.apache.airavata.jupyter.api.util.remote.interfacing.SSHInterfacingProtocol; +import org.apache.airavata.jupyter.api.util.remote.submitters.ForkJobSubmitter; +import org.apache.airavata.jupyter.api.util.remote.submitters.JobSubmitter; +import org.apache.airavata.jupyter.api.util.remote.submitters.SlurmJobSubmitter; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.security.core.Authentication; +import org.springframework.web.bind.annotation.*; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Optional; +import java.util.UUID; + +@RestController +@RequestMapping(path = "/api/remote") +public class RemoteExecController { + private static final Logger logger = LoggerFactory.getLogger(RemoteExecController.class); + + private String localWorkingDir = "/tmp"; + + @Autowired + private ComputeRepository computeRepository; + + @Autowired + private LocalInterfaceRepository localInterfaceRepository; + + @Autowired + private SSHInterfaceRepository sshInterfaceRepository; + + @Autowired + private ArchiveRepository archiveRepository; + + @Autowired + private JobRepository jobRepository; + + @Autowired + private JobStatusRepository jobStatusRepository; + + public class RunCellResponse { + private String jobId; + + public String getJobId() { + return jobId; + } + + public void setJobId(String jobId) { + this.jobId = jobId; + } + } + + @GetMapping(path = "/run/{computeId}/{archiveId}/{sessionId}") + public RunCellResponse runCell(@PathVariable String computeId, @PathVariable String archiveId, @PathVariable String sessionId) throws Exception { + + logger.info("Running cell for compute {} with state archive uploaded in to archive {}", computeId, archiveId); + + Optional archiveOp = archiveRepository.findById(archiveId); + Optional computeOp = computeRepository.findById(computeId); + if (computeOp.isPresent() && archiveOp.isPresent()) { + ComputeEntity computeEntity = computeOp.get(); + ArchiveEntity archiveEntity = archiveOp.get(); + InterfacingProtocol interfacingProtocol = resolveInterface(computeEntity.getInterfaceType(), computeEntity.getInterfaceId()); + + // Creating local working directory + String workDirForCurrent = localWorkingDir + "/" + UUID.randomUUID().toString(); + Files.createDirectory(Path.of(workDirForCurrent)); + + JobSubmitter jobSubmitter = resolveJobSubmitter(interfacingProtocol, computeEntity.getSubmitterType(), workDirForCurrent); + String jobId = jobSubmitter.submitJob(archiveEntity.getPath(), sessionId); + RunCellResponse response = new RunCellResponse(); + response.setJobId(jobId); + return response; + } else { + throw new Exception("Could not find a compute resource with id " + computeId + " or archive with id " + archiveId); + } + } + + @PostMapping(path = "/interface/local", consumes = "application/json", produces = "application/json") + public LocalInterfaceEntity createLocalInterface(Authentication authentication, @RequestBody LocalInterfaceEntity localInterfaceEntity) { + LocalInterfaceEntity saved = localInterfaceRepository.save(localInterfaceEntity); + return saved; + } + + @PostMapping(path = "/interface/ssh", consumes = "application/json", produces = "application/json") + public SSHInterfaceEntity createSSHInterface(Authentication authentication, @RequestBody SSHInterfaceEntity sshInterfaceEntity) { + SSHInterfaceEntity saved = sshInterfaceRepository.save(sshInterfaceEntity); + return saved; + } + + @PostMapping(path = "/compute", consumes = "application/json", produces = "application/json") + public ComputeEntity createCompute(Authentication authentication, @RequestBody ComputeEntity computeEntity) { + ComputeEntity saved = computeRepository.save(computeEntity); + return saved; + } + + private JobSubmitter resolveJobSubmitter(InterfacingProtocol interfacingProtocol, + ComputeEntity.SubmitterType submitterType, + String workDir) throws Exception { + switch (submitterType) { + case FORK: + return new ForkJobSubmitter(workDir, interfacingProtocol, jobRepository, jobStatusRepository); + case SLURM: + return new SlurmJobSubmitter(); + } + + throw new Exception("Could not find a job submitter with type " + submitterType.name()); + + } + + private InterfacingProtocol resolveInterface(ComputeEntity.InterfaceType interfaceType, String interfaceId) throws Exception { + + switch (interfaceType) { + case LOCAL: + Optional localInterfaceOp = localInterfaceRepository.findById(interfaceId); + if (localInterfaceOp.isPresent()) { + return new LocalInterfacingProtocol(localInterfaceOp.get().getWorkingDirectory()); + } else { + throw new Exception("Could not find a local interface with id " + interfaceId); + } + case SSH: + Optional sshInterfaceOp = sshInterfaceRepository.findById(interfaceId); + if (sshInterfaceOp.isPresent()) { + return new SSHInterfacingProtocol(sshInterfaceOp.get(), sshInterfaceOp.get().getWorkingDirectory()); + } else { + throw new Exception("Could not find a SSH interface with id " + interfaceId); + } + } + + throw new Exception("Could not find a valid interface for type " + interfaceType.name() + " and id " + interfaceId); + } +} diff --git a/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/entity/interfacing/LocalInterfaceEntity.java b/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/entity/interfacing/LocalInterfaceEntity.java new file mode 100644 index 0000000..0155d70 --- /dev/null +++ b/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/entity/interfacing/LocalInterfaceEntity.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.airavata.jupyter.api.entity.interfacing; + +import org.hibernate.annotations.GenericGenerator; + +import javax.persistence.Column; +import javax.persistence.Entity; +import javax.persistence.GeneratedValue; +import javax.persistence.Id; + +@Entity(name = "LOCAL_INTERFACE") +public class LocalInterfaceEntity { + + @Id + @Column(name = "ARCHIVE_ID") + @GeneratedValue(generator = "uuid") + @GenericGenerator(name = "uuid", strategy = "uuid2") + private String id; + + @Column(name = "WORKING_DIR") + private String workingDirectory; + + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + + public String getWorkingDirectory() { + return workingDirectory; + } + + public void setWorkingDirectory(String workingDirectory) { + this.workingDirectory = workingDirectory; + } +} diff --git a/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/entity/interfacing/SSHInterfaceEntity.java b/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/entity/interfacing/SSHInterfaceEntity.java new file mode 100644 index 0000000..6f5e59e --- /dev/null +++ b/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/entity/interfacing/SSHInterfaceEntity.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.airavata.jupyter.api.entity.interfacing; + +import org.hibernate.annotations.GenericGenerator; + +import javax.persistence.*; + +@Entity(name = "SSH_INTERFACING") +public class SSHInterfaceEntity { + + @Id + @Column(name = "ARCHIVE_ID") + @GeneratedValue(generator = "uuid") + @GenericGenerator(name = "uuid", strategy = "uuid2") + private String id; + + @Column(name = "HOSTNAME") + private String hostName; + + @Column(name = "PORT") + private Integer port; + + @Column(name = "USER_NAME") + private String userName; + + @Lob + @Column(name = "PRIVATE_KEY") + private String privateKey; + + @Lob + @Column(name = "PUBLIC_KEY") + private String publicKey; + + @Column(name = "PASSPHRASE") + private String passphrase; + + @Column(name = "WORKING_DIR") + private String workingDirectory; + + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + + public String getHostName() { + return hostName; + } + + public void setHostName(String hostName) { + this.hostName = hostName; + } + + public Integer getPort() { + return port; + } + + public void setPort(Integer port) { + this.port = port; + } + + public String getPrivateKey() { + return privateKey; + } + + public void setPrivateKey(String privateKey) { + this.privateKey = privateKey; + } + + public String getPublicKey() { + return publicKey; + } + + public void setPublicKey(String publicKey) { + this.publicKey = publicKey; + } + + public String getPassphrase() { + return passphrase; + } + + public void setPassphrase(String passphrase) { + this.passphrase = passphrase; + } + + public String getWorkingDirectory() { + return workingDirectory; + } + + public void setWorkingDirectory(String workingDirectory) { + this.workingDirectory = workingDirectory; + } + + public String getUserName() { + return userName; + } + + public void setUserName(String userName) { + this.userName = userName; + } +} diff --git a/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/entity/job/JobEntity.java b/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/entity/job/JobEntity.java new file mode 100644 index 0000000..26d9886 --- /dev/null +++ b/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/entity/job/JobEntity.java @@ -0,0 +1,48 @@ +package org.apache.airavata.jupyter.api.entity.job; + +import org.hibernate.annotations.GenericGenerator; + +import javax.persistence.Column; +import javax.persistence.Entity; +import javax.persistence.GeneratedValue; +import javax.persistence.Id; + +@Entity(name = "JOB") +public class JobEntity { + + @Id + @Column(name = "JOB_ID") + @GeneratedValue(generator = "uuid") + @GenericGenerator(name = "uuid", strategy = "uuid2") + private String id; + + @Column(name = "REMOTE_WORK_PATH") + private String remoteWorkingPath; + + @Column(name = "LOCAL_WORK_PATH") + private String localWorkingPath; + + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + + public String getRemoteWorkingPath() { + return remoteWorkingPath; + } + + public void setRemoteWorkingPath(String remoteWorkingPath) { + this.remoteWorkingPath = remoteWorkingPath; + } + + public String getLocalWorkingPath() { + return localWorkingPath; + } + + public void setLocalWorkingPath(String localWorkingPath) { + this.localWorkingPath = localWorkingPath; + } +} diff --git a/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/entity/job/JobStatusEntity.java b/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/entity/job/JobStatusEntity.java new file mode 100644 index 0000000..8d9315d --- /dev/null +++ b/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/entity/job/JobStatusEntity.java @@ -0,0 +1,78 @@ +package org.apache.airavata.jupyter.api.entity.job; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import org.hibernate.annotations.GenericGenerator; + +import javax.persistence.*; + +@Entity(name = "JOB_STATUS") +public class JobStatusEntity { + + public enum State { + SUBMITTED, + RUNNING, + FAILED, + COMPLETED, + CANCELLED + } + + @Id + @Column(name = "JOB_STATUS_ID") + @GeneratedValue(generator = "uuid") + @GenericGenerator(name = "uuid", strategy = "uuid2") + private String id; + + @Column(name = "STATE") + private State state; + + @Column(name = "UPDATED_TIME") + private long updatedTime; + + @JsonIgnoreProperties({"hibernateLazyInitializer", "handler"}) + @ManyToOne(cascade = CascadeType.ALL, fetch = FetchType.LAZY) + @JoinColumn(name = "JOB_ID", nullable = false) + private JobEntity jobEntity; + + @Column(name = "JOB_ID", insertable = false, updatable = false) + private String jobId; + + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + + public State getState() { + return state; + } + + public void setState(State state) { + this.state = state; + } + + public long getUpdatedTime() { + return updatedTime; + } + + public void setUpdatedTime(long updatedTime) { + this.updatedTime = updatedTime; + } + + public JobEntity getJobEntity() { + return jobEntity; + } + + public void setJobEntity(JobEntity jobEntity) { + this.jobEntity = jobEntity; + } + + public String getJobId() { + return jobId; + } + + public void setJobId(String jobId) { + this.jobId = jobId; + } +} diff --git a/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/entity/remote/ComputeEntity.java b/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/entity/remote/ComputeEntity.java new file mode 100644 index 0000000..dfc4609 --- /dev/null +++ b/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/entity/remote/ComputeEntity.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.airavata.jupyter.api.entity.remote; + +import org.hibernate.annotations.GenericGenerator; + +import javax.persistence.Column; +import javax.persistence.Entity; +import javax.persistence.GeneratedValue; +import javax.persistence.Id; + +@Entity(name = "LOCAL_COMPUTE") +public class ComputeEntity { + + public enum InterfaceType { + LOCAL, SSH + } + + public enum SubmitterType { + FORK, SLURM + } + + @Id + @Column(name = "ARCHIVE_ID") + @GeneratedValue(generator = "uuid") + @GenericGenerator(name = "uuid", strategy = "uuid2") + private String id; + + @Column(name = "COMPUTE_NAME") + private String computeName; + + @Column(name = "INTERFACE_TYPE") + private InterfaceType interfaceType; + + @Column(name = "INTERFACE_ID") + private String interfaceId; + + @Column(name = "SUBMITTER_TYPE") + private SubmitterType submitterType; + + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + + public String getComputeName() { + return computeName; + } + + public void setComputeName(String computeName) { + this.computeName = computeName; + } + + public InterfaceType getInterfaceType() { + return interfaceType; + } + + public void setInterfaceType(InterfaceType interfaceType) { + this.interfaceType = interfaceType; + } + + public String getInterfaceId() { + return interfaceId; + } + + public void setInterfaceId(String interfaceId) { + this.interfaceId = interfaceId; + } + + public SubmitterType getSubmitterType() { + return submitterType; + } + + public void setSubmitterType(SubmitterType submitterType) { + this.submitterType = submitterType; + } +} diff --git a/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/repo/ComputeRepository.java b/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/repo/ComputeRepository.java new file mode 100644 index 0000000..3d0febb --- /dev/null +++ b/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/repo/ComputeRepository.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.airavata.jupyter.api.repo; + +import org.apache.airavata.jupyter.api.entity.remote.ComputeEntity; +import org.springframework.data.repository.CrudRepository; + +public interface ComputeRepository extends CrudRepository { +} diff --git a/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/repo/JobRepository.java b/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/repo/JobRepository.java new file mode 100644 index 0000000..100fc5a --- /dev/null +++ b/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/repo/JobRepository.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.airavata.jupyter.api.repo; + +import org.apache.airavata.jupyter.api.entity.job.JobEntity; +import org.springframework.data.repository.CrudRepository; + +public interface JobRepository extends CrudRepository { + +} diff --git a/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/repo/JobStatusRepository.java b/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/repo/JobStatusRepository.java new file mode 100644 index 0000000..56bb7ed --- /dev/null +++ b/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/repo/JobStatusRepository.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.airavata.jupyter.api.repo; + +import org.apache.airavata.jupyter.api.entity.job.JobStatusEntity; +import org.springframework.data.repository.CrudRepository; + +import java.util.List; +import java.util.Optional; + +public interface JobStatusRepository extends CrudRepository { + Optional findFirstByJobIdOrderByUpdatedTimeAsc(String jobId); +} diff --git a/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/repo/LocalInterfaceRepository.java b/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/repo/LocalInterfaceRepository.java new file mode 100644 index 0000000..1484fe9 --- /dev/null +++ b/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/repo/LocalInterfaceRepository.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.airavata.jupyter.api.repo; + +import org.apache.airavata.jupyter.api.entity.interfacing.LocalInterfaceEntity; +import org.springframework.data.repository.CrudRepository; + +public interface LocalInterfaceRepository extends CrudRepository { +} diff --git a/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/repo/SSHInterfaceRepository.java b/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/repo/SSHInterfaceRepository.java new file mode 100644 index 0000000..12b7d78 --- /dev/null +++ b/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/repo/SSHInterfaceRepository.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.airavata.jupyter.api.repo; + +import org.apache.airavata.jupyter.api.entity.interfacing.SSHInterfaceEntity; +import org.springframework.data.repository.CrudRepository; + +public interface SSHInterfaceRepository extends CrudRepository { +} diff --git a/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/util/remote/interfacing/InterfacingProtocol.java b/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/util/remote/interfacing/InterfacingProtocol.java new file mode 100644 index 0000000..501aa5f --- /dev/null +++ b/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/util/remote/interfacing/InterfacingProtocol.java @@ -0,0 +1,49 @@ +package org.apache.airavata.jupyter.api.util.remote.interfacing; + +public abstract class InterfacingProtocol { + + private String remoteWorkingDir ; + + public InterfacingProtocol(String remoteWorkingDir) { + this.remoteWorkingDir = remoteWorkingDir; + } + + public class ExecutionResponse { + private String stdOut; + private String stdErr; + private int code; + + public String getStdOut() { + return stdOut; + } + + public void setStdOut(String stdOut) { + this.stdOut = stdOut; + } + + public String getStdErr() { + return stdErr; + } + + public void setStdErr(String stdErr) { + this.stdErr = stdErr; + } + + public int getCode() { + return code; + } + + public void setCode(int code) { + this.code = code; + } + } + + public abstract boolean createDirectory(String relativePath) throws Exception; + public abstract boolean transferFileToRemote(String localPath, String remoteRelativePath) throws Exception; + public abstract boolean transferFileFromRemote(String remoteRelativePath, String localPath) throws Exception; + public abstract ExecutionResponse executeCommand(String relativeWorkDir, String command) throws Exception; + + public String getRemoteWorkingDir() { + return remoteWorkingDir; + } +} diff --git a/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/util/remote/interfacing/LocalInterfacingProtocol.java b/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/util/remote/interfacing/LocalInterfacingProtocol.java new file mode 100644 index 0000000..588972d --- /dev/null +++ b/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/util/remote/interfacing/LocalInterfacingProtocol.java @@ -0,0 +1,72 @@ +package org.apache.airavata.jupyter.api.util.remote.interfacing; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.*; +import java.nio.file.Files; +import java.nio.file.Path; + +public class LocalInterfacingProtocol extends InterfacingProtocol { + + private static final Logger logger = LoggerFactory.getLogger(LocalInterfacingProtocol.class); + + public LocalInterfacingProtocol(String remoteWorkingDir) { + super(remoteWorkingDir); + } + + @Override + public boolean createDirectory(String relativePath) throws Exception { + Files.createDirectories(Path.of(getRemoteWorkingDir(), relativePath)); + return true; + } + + @Override + public boolean transferFileToRemote(String localPath, String remoteRelativePath) throws Exception { + Files.copy(Path.of(localPath), Path.of(getRemoteWorkingDir(), remoteRelativePath)); + return true; + } + + @Override + public boolean transferFileFromRemote(String remoteRelativePath, String localPath) throws Exception { + Files.copy(Path.of(getRemoteWorkingDir(), remoteRelativePath), Path.of(localPath)); + return true; + } + + @Override + public ExecutionResponse executeCommand(String relativeWorkDir, String command) throws Exception { + + Runtime rt = Runtime.getRuntime(); + String[] envs= {}; + Process proc = rt.exec(command, envs, new File(getRemoteWorkingDir() + "/" + relativeWorkDir)); + + BufferedReader stdInput = new BufferedReader(new + InputStreamReader(proc.getInputStream())); + + BufferedReader stdError = new BufferedReader(new + InputStreamReader(proc.getErrorStream())); + + + logger.info("Here is the standard output of the command:\n"); + String s = null; + StringBuilder stdOut = new StringBuilder(); + StringBuilder stdErr = new StringBuilder(); + while ((s = stdInput.readLine()) != null) { + logger.info(s); + stdOut.append(s).append("\n"); + } + + logger.info("Here is the standard error of the command (if any):\n"); + while ((s = stdError.readLine()) != null) { + logger.info(s); + stdErr.append(s).append("\n"); + } + + ExecutionResponse response = new ExecutionResponse(); + response.setStdOut(stdOut.toString()); + response.setStdErr(stdErr.toString()); + response.setCode(proc.exitValue()); + + return response; + } +} diff --git a/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/util/remote/interfacing/SSHInterfacingProtocol.java b/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/util/remote/interfacing/SSHInterfacingProtocol.java new file mode 100644 index 0000000..b10ec13 --- /dev/null +++ b/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/util/remote/interfacing/SSHInterfacingProtocol.java @@ -0,0 +1,163 @@ +package org.apache.airavata.jupyter.api.util.remote.interfacing; + +import net.schmizz.sshj.SSHClient; +import net.schmizz.sshj.connection.channel.direct.Session; +import net.schmizz.sshj.sftp.SFTPClient; +import net.schmizz.sshj.transport.verification.HostKeyVerifier; +import net.schmizz.sshj.transport.verification.PromiscuousVerifier; +import net.schmizz.sshj.userauth.keyprovider.KeyProvider; +import net.schmizz.sshj.userauth.method.*; +import net.schmizz.sshj.userauth.password.PasswordFinder; +import net.schmizz.sshj.userauth.password.PasswordUtils; +import net.schmizz.sshj.userauth.password.Resource; +import net.schmizz.sshj.xfer.scp.SCPFileTransfer; +import org.apache.airavata.jupyter.api.entity.interfacing.SSHInterfaceEntity; +import org.apache.commons.io.IOUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.InputStream; +import java.io.StringWriter; +import java.security.PublicKey; +import java.util.ArrayList; +import java.util.LinkedList; +import java.util.List; +import java.util.concurrent.TimeUnit; + +public class SSHInterfacingProtocol extends InterfacingProtocol { + + private static final Logger logger = LoggerFactory.getLogger(SSHInterfacingProtocol.class); + + private SSHInterfaceEntity sshInterfaceEntity; + private SSHClient sshClient; + + public SSHInterfacingProtocol(SSHInterfaceEntity sshInterfaceEntity, String remoteWorkingDir) throws Exception { + super(remoteWorkingDir); + this.sshInterfaceEntity = sshInterfaceEntity; + this.sshClient = createSSHClient(); + } + + + private SSHClient createSSHClient() throws Exception { + + try { + SSHClient sshClient = new SSHClient(); + sshClient.addHostKeyVerifier(new PromiscuousVerifier()); + sshClient.connect(sshInterfaceEntity.getHostName(), sshInterfaceEntity.getPort()); + sshClient.getConnection().getKeepAlive().setKeepAliveInterval(5); + + + PasswordFinder passwordFinder = sshInterfaceEntity.getPassphrase() != null ? + PasswordUtils.createOneOff(sshInterfaceEntity.getPassphrase().toCharArray()) : null; + + KeyProvider keyProvider = sshClient.loadKeys(sshInterfaceEntity.getPrivateKey(), + sshInterfaceEntity.getPublicKey(), passwordFinder); + + final List am = new LinkedList<>(); + // am.add(new AbstractAuthMethod("none") {}); + am.add(new AuthPublickey(keyProvider)); + + am.add(new AuthKeyboardInteractive(new ChallengeResponseProvider() { + @Override + public List getSubmethods() { + return new ArrayList<>(); + } + + @Override + public void init(Resource resource, String name, String instruction) { + + } + + @Override + public char[] getResponse(String prompt, boolean echo) { + return new char[0]; + } + + @Override + public boolean shouldRetry() { + return false; + } + })); + + sshClient.auth(sshInterfaceEntity.getUserName(), am); + return sshClient; + } catch (Exception e) { + logger.error("Failed to create ssh connection for host {} on port {} and user {}", + sshInterfaceEntity.getHostName(), + sshInterfaceEntity.getPort(), + sshInterfaceEntity.getUserName()); + throw e; + } + } + + @Override + public boolean createDirectory(String relativePath) throws Exception { + if (!sshClient.isConnected()) { + createSSHClient(); + } + + SFTPClient sftpClient = sshClient.newSFTPClient(); + sftpClient.mkdirs(getRemoteWorkingDir() + "/" + relativePath); + sftpClient.close(); + return true; + } + + @Override + public boolean transferFileToRemote(String localPath, String remoteRelativePath) throws Exception { + + if (!sshClient.isConnected()) { + createSSHClient(); + } + + SCPFileTransfer scpFileTransfer = sshClient.newSCPFileTransfer(); + scpFileTransfer.upload(localPath, getRemoteWorkingDir() + "/" + remoteRelativePath); + return true; + } + + @Override + public boolean transferFileFromRemote(String remoteRelativePath, String localPath) throws Exception { + + if (!sshClient.isConnected()) { + createSSHClient(); + } + + SCPFileTransfer scpFileTransfer = sshClient.newSCPFileTransfer(); + scpFileTransfer.download(getRemoteWorkingDir() + "/" + remoteRelativePath, localPath); + return true; + } + + @Override + public ExecutionResponse executeCommand(String relativeWorkDir, String command) throws Exception { + + if (!sshClient.isConnected()) { + createSSHClient(); + } + + Session session = null; + ExecutionResponse response = null; + try { + session = sshClient.startSession(); + final Session.Command execResult = session.exec("cd " + getRemoteWorkingDir() + "/" + relativeWorkDir + "; " + command); + response = new ExecutionResponse(); + response.setStdOut(readStringFromStream(execResult.getInputStream())); + response.setStdErr(readStringFromStream(execResult.getErrorStream())); + + execResult.join(5, TimeUnit.SECONDS); + response.setCode(execResult.getExitStatus()); + + } finally { + if (session != null) { + session.close(); + } + } + return response; + } + + + private String readStringFromStream(InputStream is) throws IOException { + StringWriter writer = new StringWriter(); + IOUtils.copy(is, writer, "UTF-8"); + return writer.toString(); + } +} diff --git a/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/util/remote/submitters/ForkJobSubmitter.java b/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/util/remote/submitters/ForkJobSubmitter.java new file mode 100644 index 0000000..926604d --- /dev/null +++ b/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/util/remote/submitters/ForkJobSubmitter.java @@ -0,0 +1,149 @@ +package org.apache.airavata.jupyter.api.util.remote.submitters; + +import org.apache.airavata.jupyter.api.entity.job.JobEntity; +import org.apache.airavata.jupyter.api.entity.job.JobStatusEntity; +import org.apache.airavata.jupyter.api.repo.JobRepository; +import org.apache.airavata.jupyter.api.repo.JobStatusRepository; +import org.apache.airavata.jupyter.api.util.remote.interfacing.InterfacingProtocol; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.BufferedWriter; +import java.io.FileWriter; +import java.util.UUID; + +public class ForkJobSubmitter implements JobSubmitter { + + private static final Logger logger = LoggerFactory.getLogger(ForkJobSubmitter.class); + + private InterfacingProtocol interfacingProtocol; + + private String codeTemplate = "import dill as pickle\n" + + "import shutil\n" + + "import json\n" + + "import subprocess\n"+ + "import sys\n" + + "import os\n" + + "\n" + + "f = open('ARCHIVE/files.json')\n" + + "files_json = json.load(f)\n" + + "for v in files_json:\n" + + " target_path = files_json[v]\n" + + " if not target_path.startswith(\"/\"):\n" + + " target_path = target_path\n" + + "\n" + + " dir_path = os.path.dirname(target_path)\n" + + " if not dir_path == \"\":\n" + + " os.makedirs(dir_path, exist_ok = True)\n" + + "\n" + + " shutil.copyfile('ARCHIVE/' + v, target_path)\n" + + "f = open('ARCHIVE/dependencies.json')\n" + + "dep_json = json.load(f)\n" + + "subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"--upgrade\", \"pip\"])\n" + + "for dep_name in dep_json:\n" + + " dep_version = dep_json[dep_name]\n" + + " subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", dep_name + \"==\" + dep_version])\n" + + "with open('ARCHIVE/context.p', 'rb') as f:\n" + + " context = pickle.load(f)\n" + + "\n" + + "with open('ARCHIVE/code.txt') as f:" + + "\n" + + " code = f.read()" + + "\n" + + "exec(code, None, context)" + + "\n" + + "with open('ARCHIVE/final-context.p', 'wb') as f:\n" + + " pickle.dump(context, f)\n"; + + private String localWorkingDir; + private JobRepository jobRepository; + private JobStatusRepository jobStatusRepository; + + public ForkJobSubmitter(String localWorkingDir, + InterfacingProtocol interfacingProtocol, + JobRepository jobRepository, + JobStatusRepository jobStatusRepository) { + this.localWorkingDir = localWorkingDir; + this.interfacingProtocol = interfacingProtocol; + this.jobRepository = jobRepository; + this.jobStatusRepository = jobStatusRepository; + } + + @Override + public String submitJob(String archivePath, String sessionId) throws Exception { + String expDir = UUID.randomUUID().toString(); + String sessionDir = "sessions/" + sessionId; + logger.info("Using experiment directory {} and working directory {}", + interfacingProtocol.getRemoteWorkingDir() + "/" + expDir, localWorkingDir); + interfacingProtocol.createDirectory(expDir); + + // This to store a virtual environment which can be reused across the same session + interfacingProtocol.createDirectory(sessionDir); + logger.info("Created exp dir in {} and session dir in {}", expDir, sessionDir); + + InterfacingProtocol.ExecutionResponse response = interfacingProtocol.executeCommand(sessionDir, "python3 -m venv venv --system-site-packages"); + if (response.getCode() != 0) { + logger.error("Failed to create the virtual environment. Stderr: " + response.getStdErr() + " Std out " + response.getStdOut()); + throw new Exception("Failed to create the virtual environment"); + } + + String pythonCommand = interfacingProtocol.getRemoteWorkingDir() + "/" + sessionDir + "/venv/bin/python3"; + String pipCommand = interfacingProtocol.getRemoteWorkingDir() + "/" + sessionDir + "/venv/bin/pip3"; + // TODO Save in database + + interfacingProtocol.transferFileToRemote(archivePath, expDir + "/ARCHIVE.zip"); + interfacingProtocol.executeCommand(expDir, "unzip ARCHIVE.zip -d ARCHIVE"); + + BufferedWriter writer = new BufferedWriter(new FileWriter(localWorkingDir + "/wrapper_code.py")); + writer.write(codeTemplate); + writer.flush(); + writer.close(); + + interfacingProtocol.transferFileToRemote(localWorkingDir + "/wrapper_code.py", expDir + "/wrapper_code.py"); + + interfacingProtocol.executeCommand(expDir, pipCommand + " install dill"); + InterfacingProtocol.ExecutionResponse executionResponse = interfacingProtocol.executeCommand(expDir, pythonCommand + " wrapper_code.py"); + + writer = new BufferedWriter(new FileWriter(localWorkingDir + "/stdout.txt")); + writer.write(executionResponse.getStdOut()); + writer.flush(); + writer.close(); + + writer = new BufferedWriter(new FileWriter(localWorkingDir + "/stderr.txt")); + writer.write(executionResponse.getStdErr()); + writer.flush(); + writer.close(); + + writer = new BufferedWriter(new FileWriter(localWorkingDir + "/state-code.txt")); + writer.write(executionResponse.getCode() + ""); + writer.flush(); + writer.close(); + + interfacingProtocol.transferFileFromRemote(expDir + "/ARCHIVE/final-context.p", localWorkingDir + "/final-context.p"); + + logger.info("Completed running cell and placed output in {}", localWorkingDir); + + JobEntity jobEntity = new JobEntity(); + jobEntity.setLocalWorkingPath(localWorkingDir); + jobEntity.setRemoteWorkingPath(expDir); + JobEntity savedJob = jobRepository.save(jobEntity); + + JobStatusEntity jobStatusEntity = new JobStatusEntity(); + jobStatusEntity.setJobEntity(jobEntity); + jobStatusEntity.setState(JobStatusEntity.State.COMPLETED); + jobStatusEntity.setUpdatedTime(System.currentTimeMillis()); + jobStatusRepository.save(jobStatusEntity); + + return savedJob.getId(); + } + + @Override + public String getJobStatus(String jobId) { + return null; + } + + @Override + public String cancelJob(String jobId) { + return null; + } +} diff --git a/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/util/remote/submitters/JobSubmitter.java b/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/util/remote/submitters/JobSubmitter.java new file mode 100644 index 0000000..b0bdcdb --- /dev/null +++ b/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/util/remote/submitters/JobSubmitter.java @@ -0,0 +1,8 @@ +package org.apache.airavata.jupyter.api.util.remote.submitters; + +public interface JobSubmitter { + + public String submitJob(String archivePath, String sessionId) throws Exception; + public String getJobStatus(String jobId) throws Exception; + public String cancelJob(String jobId) throws Exception; +} diff --git a/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/util/remote/submitters/SlurmJobSubmitter.java b/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/util/remote/submitters/SlurmJobSubmitter.java new file mode 100644 index 0000000..e151952 --- /dev/null +++ b/jupyter-platform/platform-api/src/main/java/org/apache/airavata/jupyter/api/util/remote/submitters/SlurmJobSubmitter.java @@ -0,0 +1,18 @@ +package org.apache.airavata.jupyter.api.util.remote.submitters; + +public class SlurmJobSubmitter implements JobSubmitter { + @Override + public String submitJob(String archivePath, String sessionId) throws Exception { + return null; + } + + @Override + public String getJobStatus(String jobId) throws Exception { + return null; + } + + @Override + public String cancelJob(String jobId) throws Exception { + return null; + } +} diff --git a/jupyter-platform/platform-api/src/main/resources/log4j2.xml b/jupyter-platform/platform-api/src/main/resources/log4j2.xml index ec81228..7b7ec2f 100644 --- a/jupyter-platform/platform-api/src/main/resources/log4j2.xml +++ b/jupyter-platform/platform-api/src/main/resources/log4j2.xml @@ -24,7 +24,7 @@ - + diff --git a/tools/state_capture_magic/StateCaptureMagic/__init__.py b/tools/state_capture_magic/StateCaptureMagic/__init__.py index 6070e4a..2ed0cea 100644 --- a/tools/state_capture_magic/StateCaptureMagic/__init__.py +++ b/tools/state_capture_magic/StateCaptureMagic/__init__.py @@ -28,6 +28,7 @@ import os.path from os.path import exists import socket +import uuid def load_ipython_extension(ipython): ipython.register_magics(StateCaptureMagic) @@ -40,6 +41,9 @@ class StateCaptureMagic(Magics): def load_local_context(self, line, cell="", local_ns=None): context_file = "/opt/ARCHIVE/context.p" + self.load_context_from_file(context_file, local_ns) + + def load_context_from_file(self, context_file, local_ns): if exists(context_file): with open(context_file, "rb") as input_file: @@ -52,6 +56,7 @@ def load_local_context(self, line, cell="", local_ns=None): else: print("No archive is loaded or context is not exported to the archive") + def control_tracing(self, turn_on): pid = os.getpid() @@ -65,22 +70,148 @@ def control_tracing(self, turn_on): try: if turn_on: - print("Turning on tracing") + # print("Turning on tracing") message = "START:" + str(pid) sock.sendall(str.encode(message)) else: - print("Turning off tracing") + # print("Turning off tracing") message = "STOP:" + str(pid) sock.sendall(str.encode(message)) finally: sock.close() + @cell_magic + @needs_local_scope + def run_hpc(self, line, cell, local_ns=None): + + + try: + print("Session ID " + self.session_id) + except AttributeError: + self.session_id = str(uuid.uuid4()) + print("Session ID " + self.session_id) + + self.control_tracing(False) + + parameters = line.split(",") + + uploadServer = None + archiveName = "HPC Export" + compute_id = None + ignoredDeps = set() + + + for param in parameters: + param = param.strip() + + if param.startswith("uploadServer"): + uploadServer = param.split("=")[1] + + if param.startswith("computeResourceId"): + compute_id = param.split("=")[1] + + if param.startswith("ignoredDependencies"): + depListStr = param.split("=")[1] + depsList = depListStr.split(";") + for dep in depsList: + ignoredDeps.add(dep) + + accessed_files = self.get_accessed_files() + dependencies = self.get_dependencies(local_ns, ignoredDeps) + archive_dir = "ARCHIVE" + + dirpath = Path(archive_dir) + if dirpath.exists() and dirpath.is_dir(): + shutil.rmtree(dirpath) + + os.mkdir(archive_dir) + + archive_dict = {} + for f in accessed_files: + if not f == "" and exists(f): + letters = string.ascii_lowercase + random_name = ''.join(random.choice(letters) for i in range(10)); + archive_dict[random_name] = f + shutil.copyfile(f, archive_dir + "/" + random_name) + + f = open(archive_dir + "/files.json", "w") + json.dump(archive_dict, f) + f.close() + + f = open(archive_dir + "/dependencies.json", "w") + json.dump(dependencies, f) + f.close() + + var_context = self.get_local_context(local_ns) + + pickle.dump( var_context, open( archive_dir + "/context.p", "wb" )) + + f = open(archive_dir + "/code.txt", "w") + f.write(cell) + f.close() + + shutil.make_archive("ARCHIVE", 'zip', archive_dir) + # print("Download the state export ") + # display(FileLink("ARCHIVE.zip")) + + if uploadServer: + archiveId = self.upload_and_register_archive_in_server(uploadServer, archiveName) + if archiveId: + # print("Uploaded to archive " + archiveId) + run_url = uploadServer + "/remote/run/" + compute_id + "/" + archiveId + "/" + self.session_id + # print("Executing run url " + run_url) + response = requests.get(run_url) + if response.status_code == 200: + + job_id = response.json()['jobId'] + # print ("Received job id " + job_id) + + monitor_url = uploadServer + "/job/status/" + job_id + response = requests.get(monitor_url) + if response.status_code == 200: + job_state = response.json()["state"] + # print("Job state is " + job_state) + if job_state == "COMPLETED": + state_download_url = uploadServer + "/archive/download/" + job_id + downloaded_zip = self.download_file_from_url(state_download_url) + if downloaded_zip: + shutil.unpack_archive(downloaded_zip, "LATEST_STATE") + self.load_context_from_file("LATEST_STATE/final-context.p", local_ns) + with open('LATEST_STATE/stdout.txt', 'r') as f: + print(f.read()) + + with open('LATEST_STATE/stderr.txt', 'r') as f: + print(f.read()) + else: + print("Downloading state archive from remote failed") + else: + print("Failed while monitoring job " + job_id) + else: + print("Cell execution submission failed") + + else: + print("Upload to server failed") + + self.control_tracing(True) + + def download_file_from_url(self, download_url): + resp = requests.get(download_url, allow_redirects=True) + disp_header = resp.headers.get('content-disposition') + if disp_header : + file_name = re.findall('filename=(.+)', disp_header) + if len(file_name) == 0: + return None + + open(file_name[0], 'wb').write(resp.content) + print("File " + file_name[0] + " was downloaded") + return file_name[0] + else: + print("Could not find the file name in download response") + @line_magic @needs_local_scope def export_states(self, line, cell="", local_ns=None): - pid = os.getpid() - self.control_tracing(False) parameters = line.split(",") @@ -108,68 +239,13 @@ def export_states(self, line, cell="", local_ns=None): if flag == "True": captureLocalContext = True - #print(createArchive) - #print(uploadServer) - - - log_file = "/tmp/p" + str(pid) - f = open(log_file) - raw = f.read() - lines = raw.splitlines() - - ignore_list = ['/usr', '/lib','/home/dimuthu/.ipython/', "/dev", ".so", "/proc", - "/etc", "/tmp/pip-", "/home/dimuthu/.cache", "/root/.cache", - "ARCHIVE", "/root/.local", "dependencies.json", "context.p", - "files.json", "metadata.json"] - notebook_name = self.get_notebook_name() - for path in sys.path[1:]: - if path: - ignore_list.append(path) - - accessed_files = set() - for line in lines: - if line.count("O_NOFOLLOW") > 0: - continue - if line.count("O_DIRECTORY") > 0: - continue - if line.count("AT_FDCWD") == 0: - continue - - if line.count("\"") > 1: - start = line.index("\"") - end = line.index("\"", start + 1) - path = line[start+1 : end] - should_ignore = False - for ignore in ignore_list: - if path.count(ignore): - should_ignore = True - break - if not should_ignore: - accessed_files.add(path) - - if log_file in accessed_files: - accessed_files.remove(log_file) + accessed_files = self.get_accessed_files() accessed_files.add(notebook_name) - - def imports(): - for name, val in local_ns.items(): - if isinstance(val, types.ModuleType): - yield val.__name__ - - import_list = list(imports()) - dependencies = {} - for imp in import_list: - try: - if imp.count(".") > 0: - imp = imp.split(".")[0] - dependencies[imp] = version(imp) - except PackageNotFoundError: - pass + dependencies = self.get_dependencies(local_ns) if createArchive: - archive_dict = {} archive_dir = "ARCHIVE" dirpath = Path(archive_dir) @@ -178,6 +254,7 @@ def imports(): os.mkdir(archive_dir) + archive_dict = {} for f in accessed_files: if not f == "" and exists(f): letters = string.ascii_lowercase @@ -200,15 +277,8 @@ def imports(): f.close() if captureLocalContext: - local_variables = self.get_magic_out("who")[:-1] - var_context = {} - for var_name in local_variables: - var_context[var_name] = local_ns[var_name] - try: - pickle.dumps(var_context) - except: - print("Warning: Variable " + var_name + " can not be exported") - var_context.pop(var_name) + var_context = self.get_local_context(local_ns) + pickle.dump( var_context, open( archive_dir + "/context.p", "wb" )) shutil.make_archive("ARCHIVE", 'zip', archive_dir) @@ -216,7 +286,7 @@ def imports(): display(FileLink("ARCHIVE.zip")) if uploadServer: - self.upload_to_server(uploadServer, archiveName) + self.upload_and_register_archive_in_server(uploadServer, archiveName) self.control_tracing(True) @@ -224,15 +294,25 @@ def imports(): return {"accessed_files": list(accessed_files), "dependencies": dependencies} - def upload_to_server(self, base_url, archiveName): + def upload_archive_to_server(self, base_url): + headers={'Accept': 'application/json, text/plain, */*', + 'Accept-Encoding': 'gzip, deflate, br', + 'Accept-Language': 'en-US,en;q=0.9', + 'Connection': 'keep-alive'} + + files = {'file': open('ARCHIVE.zip', 'rb')} + response = requests.post(base_url + "/archive/upload", data={}, headers=headers, files=files) + return response + + def upload_and_register_archive_in_server(self, base_url, archiveName): headers={'Accept': 'application/json, text/plain, */*', 'Accept-Encoding': 'gzip, deflate, br', 'Accept-Language': 'en-US,en;q=0.9', 'Connection': 'keep-alive'} - files = {'file': open('ARCHIVE.zip', 'rb')} - response = requests.post(base_url + "/archive/upload", data={}, headers=headers, files=files) + response = self.upload_archive_to_server(base_url) + if response.status_code == 200: response_json = response.json() archive_json = {"path": response_json["path"],"description": archiveName} @@ -240,6 +320,7 @@ def upload_to_server(self, base_url, archiveName): response = requests.post(base_url + '/archive/', data=json.dumps(archive_json), headers=headers) if response.status_code == 200: print("Archive with name " + archive_json["description"] + " was uploaded") + return response.json()["id"] else: print("Failed to create archive metadata in server with status code " + str(response.status_code)) else: @@ -268,3 +349,77 @@ def get_notebook_name(self): if nn['kernel']['id'] == kernel_id: relative_path = nn['notebook']['path'] return os.path.join(ss['notebook_dir'], relative_path) + + def get_accessed_files(self): + + pid = os.getpid() + log_file = "/tmp/p" + str(pid) + f = open(log_file) + raw = f.read() + lines = raw.splitlines() + + ignore_list = ['/usr', '/lib','/home/dimuthu/.ipython/', "/dev", ".so", "/proc", + "/etc", "/tmp/pip-", "/home/dimuthu/.cache", "/root/.cache", + "ARCHIVE", "/root/.local", "dependencies.json", "context.p", "/root/.keras", + "files.json", "metadata.json"] + + for path in sys.path[1:]: + if path: + ignore_list.append(path) + + accessed_files = set() + for line in lines: + if line.count("O_NOFOLLOW") > 0: + continue + if line.count("O_DIRECTORY") > 0: + continue + if line.count("AT_FDCWD") == 0: + continue + + if line.count("\"") > 1: + start = line.index("\"") + end = line.index("\"", start + 1) + path = line[start+1 : end] + should_ignore = False + for ignore in ignore_list: + if path.count(ignore): + should_ignore = True + break + if not should_ignore: + accessed_files.add(path) + + if log_file in accessed_files: + accessed_files.remove(log_file) + + return accessed_files + + def get_dependencies(self, local_ns, ignoredDeps = set()): + def imports(): + for name, val in local_ns.items(): + if isinstance(val, types.ModuleType): + yield val.__name__ + + import_list = list(imports()) + dependencies = {} + for imp in import_list: + try: + if imp.count(".") > 0: + imp = imp.split(".")[0] + if not imp in ignoredDeps: + dependencies[imp] = version(imp) + except PackageNotFoundError: + pass + return dependencies + + def get_local_context(self, local_ns): + var_context = {} + local_variables = self.get_magic_out("who")[:-1] + for var_name in local_variables: + var_context[var_name] = local_ns[var_name] + try: + pickle.dumps(var_context) + except: + print("Warning: Variable " + var_name + " can not be exported") + var_context.pop(var_name) + + return var_context