Skip to content

Commit

Permalink
Merge pull request #1355 from Giskard-AI/GSK-1623-secure-readonly-dem…
Browse files Browse the repository at this point in the history
…o-space

[GSK-1623] Secure Giskard readonly demo space at Hugging Face Spaces
  • Loading branch information
kevinmessiaen authored Sep 20, 2023
2 parents eea8171 + 9cdf5b1 commit d5eaa73
Show file tree
Hide file tree
Showing 32 changed files with 433 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ public class ApplicationProperties {
private long tokenValidityInSecondsForRememberMe = 2592000; // 30 days;
private CorsConfiguration cors = new CorsConfiguration();

private String defaultApiKey;
private String hfDemoSpaceUnlockToken;
private String mixpanelProjectKey;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ public class SecurityConfiguration {
private final ApiKeyService apiKeyService;


public static final String GISKARD_API_ENDPOINTS = "/api/**";

@Bean
public SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
http
Expand Down Expand Up @@ -75,7 +77,7 @@ public SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
antMatcher("/management/**")
).hasAuthority(AuthoritiesConstants.ADMIN)
.requestMatchers(antMatcher("/public-api/**")).hasAuthority(AuthoritiesConstants.API)
.requestMatchers(antMatcher("/api/**")).authenticated()
.requestMatchers(antMatcher(GISKARD_API_ENDPOINTS)).authenticated()
)
.sessionManagement(session -> session.sessionCreationPolicy(SessionCreationPolicy.STATELESS))
.apply(securityConfigurerAdapter());
Expand Down
4 changes: 4 additions & 0 deletions backend/src/main/java/ai/giskard/domain/ApiKey.java
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
package ai.giskard.domain;

import ai.giskard.security.GalleryDatabaseOperationListener;
import jakarta.persistence.*;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
import org.apache.commons.lang3.RandomStringUtils;

import java.util.UUID;

@Entity(name = "api_keys")
@Getter
@Setter
@NoArgsConstructor
@EntityListeners(GalleryDatabaseOperationListener.class)
public class ApiKey extends AbstractAuditingEntity {

public static final String PREFIX = "gsk-";
Expand Down
3 changes: 2 additions & 1 deletion backend/src/main/java/ai/giskard/domain/Project.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import ai.giskard.domain.ml.Dataset;
import ai.giskard.domain.ml.ProjectModel;
import ai.giskard.domain.ml.TestSuite;
import ai.giskard.security.GalleryDatabaseOperationListener;
import com.fasterxml.jackson.annotation.JsonIdentityInfo;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonProperty;
Expand All @@ -22,7 +23,7 @@

@Entity(name = "projects")
@NoArgsConstructor
@EntityListeners(AuditingEntityListener.class)
@EntityListeners({AuditingEntityListener.class, GalleryDatabaseOperationListener.class})
@JsonIdentityInfo(generator = ObjectIdGenerators.PropertyGenerator.class, property = "key")
public class Project extends AbstractAuditingEntity {
@Serial
Expand Down
3 changes: 3 additions & 0 deletions backend/src/main/java/ai/giskard/domain/Role.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package ai.giskard.domain;

import ai.giskard.security.GalleryDatabaseOperationListener;
import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.EntityListeners;
import jakarta.persistence.Table;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Size;
Expand All @@ -12,6 +14,7 @@
*/
@Entity
@Table(name = "role")
@EntityListeners(GalleryDatabaseOperationListener.class)
public class Role extends BaseEntity {
@lombok.Setter
@lombok.Getter
Expand Down
2 changes: 2 additions & 0 deletions backend/src/main/java/ai/giskard/domain/User.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ai.giskard.domain;

import ai.giskard.config.Constants;
import ai.giskard.security.GalleryDatabaseOperationListener;
import com.fasterxml.jackson.annotation.JsonIdentityInfo;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.ObjectIdGenerators;
Expand Down Expand Up @@ -29,6 +30,7 @@
@Setter
@NotNull
@JsonIdentityInfo(generator = ObjectIdGenerators.PropertyGenerator.class, property = "login")
@EntityListeners(GalleryDatabaseOperationListener.class)
public class User extends AbstractAuditingEntity {
@Serial
private static final long serialVersionUID = 0L;
Expand Down
2 changes: 2 additions & 0 deletions backend/src/main/java/ai/giskard/domain/ml/ProjectModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import ai.giskard.domain.AbstractAuditingEntity;
import ai.giskard.domain.Project;
import ai.giskard.security.GalleryDatabaseOperationListener;
import ai.giskard.utils.SimpleJSONStringAttributeConverter;
import com.fasterxml.jackson.annotation.JsonBackReference;
import lombok.AllArgsConstructor;
Expand All @@ -20,6 +21,7 @@
@Setter
@NoArgsConstructor
@AllArgsConstructor
@EntityListeners(GalleryDatabaseOperationListener.class)
public class ProjectModel extends AbstractAuditingEntity {
@Serial
private static final long serialVersionUID = 0L;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ public final class AuthoritiesConstants {
ADMIN, "Admin"
);

public static final String[] AUTHORITIES = AUTHORITY_NAMES.keySet().toArray(new String[0]);

private AuthoritiesConstants() {
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package ai.giskard.security;

import ai.giskard.service.GeneralSettingsService;
import ai.giskard.service.InitService;
import ai.giskard.web.rest.errors.GalleryDemoSpaceException;
import jakarta.persistence.PrePersist;
import jakarta.persistence.PreRemove;
import jakarta.persistence.PreUpdate;
import org.springframework.beans.factory.annotation.Configurable;

@Configurable
public class GalleryDatabaseOperationListener {
@PrePersist
@PreUpdate
@PreRemove
void beforeEntityModification(Object entity) {
if (GeneralSettingsService.IS_RUNNING_IN_DEMO_HF_SPACES && !InitService.isUnlocked()) {
throw new GalleryDemoSpaceException("This is a read-only Giskard Gallery instance. You cannot modify entities " + entity.getClass().getName());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ public void doTrack(String eventName, JSONObject props) {
serverProps.put("Giskard Version", buildVersion);
serverProps.put("Giskard Plan", licenseService.getCurrentLicense().getPlanCode());
serverProps.put("Giskard LicenseID", licenseService.getCurrentLicense().getId());
serverProps.put("Is HuggingFace", GeneralSettingsService.isRunningInHFSpaces);
serverProps.put("HuggingFace Space ID", GeneralSettingsService.hfSpaceId);
serverProps.put("Is HuggingFace", GeneralSettingsService.IS_RUNNING_IN_HFSPACES);
serverProps.put("HuggingFace Space ID", GeneralSettingsService.HF_SPACE_ID);
messageBuilder.set(settingsService.getSettings().getInstanceId(), serverProps);
} catch (NoSuchElementException e) {
// Do not track when we failed to initialize the server properties
Expand Down
23 changes: 23 additions & 0 deletions backend/src/main/java/ai/giskard/service/ApiKeyService.java
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
package ai.giskard.service;

import ai.giskard.config.ApplicationProperties;
import ai.giskard.domain.ApiKey;
import ai.giskard.domain.User;
import ai.giskard.repository.ApiKeyRepository;
import ai.giskard.repository.UserRepository;
import jakarta.annotation.PostConstruct;
import jakarta.validation.constraints.NotNull;
import lombok.RequiredArgsConstructor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;

Expand All @@ -18,7 +22,11 @@
@Service
@RequiredArgsConstructor
public class ApiKeyService {
private final Logger log = LoggerFactory.getLogger(ApiKeyService.class);

private final ApiKeyRepository apiKeyRepository;
private final ApplicationProperties applicationProperties;
private final UserRepository userRepository;
private final Map<String, ApiKey> apiKeysCache = new ConcurrentHashMap<>();

@PostConstruct
Expand Down Expand Up @@ -48,4 +56,19 @@ public List<ApiKey> deleteKey(String username, UUID key) {
.ifPresent(k -> apiKeysCache.remove(k.getKey()));
return getKeys(username);
}

public void initDefaultApiKey() {
// Create a default API key if set in env
if (applicationProperties.getDefaultApiKey() != null) {
// Add a key provided from env to connect external MLWorker in HF
Optional<User> user = userRepository.findOneByLogin("admin");
if (user.isPresent() && ApiKey.doesStringLookLikeApiKey(applicationProperties.getDefaultApiKey())) {
ApiKey key = new ApiKey(user.get());
key.setKey(applicationProperties.getDefaultApiKey());
apiKeysCache.put(applicationProperties.getDefaultApiKey(), key);
} else {
log.warn("API Key provided but not conforming format.");
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;

import java.util.Objects;
import java.util.Optional;
import java.util.stream.Stream;

Expand All @@ -21,9 +22,21 @@ public class GeneralSettingsService {

private final GeneralSettingsRepository settingsRepository;

public static final boolean isRunningInHFSpaces = Stream.of("SPACE_REPO_NAME", "SPACE_ID", "SPACE_HOST").allMatch(System.getenv()::containsKey);
private static final String SPACE_REPO_NAME_ENV = "SPACE_REPO_NAME";
private static final String SPACE_HOST_ENV = "SPACE_HOST";
private static final String SPACE_ID_ENV = "SPACE_ID";
private static final String DEMO_SPACE_ID_ENV = "DEMO_SPACE_ID";
private static final String DEFAULT_DEMO_SPACE_ID = "giskardai/giskard";

public static final String hfSpaceId = System.getenv().get("SPACE_ID");
public static final boolean IS_RUNNING_IN_HFSPACES =
Stream.of(SPACE_REPO_NAME_ENV, SPACE_ID_ENV, SPACE_HOST_ENV).allMatch(System.getenv()::containsKey);

public static final String HF_SPACE_ID = System.getenv().get(SPACE_ID_ENV);

public static final boolean IS_RUNNING_IN_DEMO_HF_SPACES = Objects.equals(
System.getenv().get(SPACE_ID_ENV),
System.getenv().get(DEMO_SPACE_ID_ENV) == null ? DEFAULT_DEMO_SPACE_ID : System.getenv().get(DEMO_SPACE_ID_ENV)
);

public GeneralSettings getSettings() {
return deserializeSettings(settingsRepository.getMandatoryById(SerializedGiskardGeneralSettings.SINGLE_ID).getSettings());
Expand Down
13 changes: 11 additions & 2 deletions backend/src/main/java/ai/giskard/service/InitService.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import ai.giskard.security.AuthoritiesConstants;
import ai.giskard.service.ee.FeatureFlag;
import ai.giskard.service.ee.LicenseService;
import ai.giskard.utils.GalleryUnlockIndicator;
import ai.giskard.web.dto.DataUploadParamsDTO;
import ai.giskard.web.dto.ModelUploadParamsDTO;
import ai.giskard.web.dto.mapper.GiskardMapper;
Expand Down Expand Up @@ -74,8 +75,9 @@ public class InitService {
private final GeneralSettingsService generalSettingsService;
private final FileLocationService fileLocationService;
private final LicenseService licenseService;
private final ApiKeyService apiKeyService;
private Map<String, ProjectConfig> projects;
String[] mockKeys = stream(AuthoritiesConstants.AUTHORITIES).map(key -> key.replace("ROLE_", "")).toArray(String[]::new);
String[] mockKeys = stream(AuthoritiesConstants.AUTHORITY_NAMES.keySet().toArray(new String[0])).map(key -> key.replace("ROLE_", "")).toArray(String[]::new);
private final Map<String, String> users = stream(mockKeys).collect(Collectors.toMap(String::toLowerCase, String::toLowerCase));
private final DatasetRepository datasetRepository;
private final ModelRepository modelRepository;
Expand Down Expand Up @@ -143,19 +145,26 @@ public String getProjectByCreatorLogin(String login) {
return projects.entrySet().stream().filter(e -> e.getValue().creator.equals(login)).findFirst().orElseThrow().getValue().name;
}

static final GalleryUnlockIndicator lockIndicator = new GalleryUnlockIndicator();
public static boolean isUnlocked() { return lockIndicator.isUnlocked(); }
public static void setUnlocked(boolean unlocked) { lockIndicator.setUnlocked(unlocked); }

/**
* Initializing first authorities, mock users, and mock projects
*/
@EventListener(ApplicationReadyEvent.class)
public void init() {
projects = createProjectConfigMap();
generalSettingsService.saveIfNotExists(new GeneralSettings());
// change readonly
initAuthorities();
initUsers();
apiKeyService.initDefaultApiKey();
List<String> profiles = Arrays.asList(env.getActiveProfiles());
if (!profiles.contains("prod") && !profiles.contains("dev")) {
initProjects();
}
lockIndicator.setUnlocked(false);
}

/**
Expand Down Expand Up @@ -183,7 +192,7 @@ public void initUsers() {
* Initiating authorities with AuthoritiesConstants values
*/
public void initAuthorities() {
stream(AuthoritiesConstants.AUTHORITIES).forEach(authName -> {
stream(AuthoritiesConstants.AUTHORITY_NAMES.keySet().toArray(new String[0])).forEach(authName -> {
if (roleRepository.findByName(authName).isPresent()) {
logger.info("Authority {} already exists", authName);
return;
Expand Down
3 changes: 2 additions & 1 deletion backend/src/main/java/ai/giskard/service/ModelService.java
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,10 @@ private MLWorkerWSRunModelForDataFrameDTO getRunModelForDataFrameResponse(Projec
param.setColumnDtypes(dataset.getColumnDtypes());
}

MLWorkerID workerID = model.getProject().isUsingInternalWorker() ? MLWorkerID.INTERNAL : MLWorkerID.EXTERNAL;
// Perform the runModelForDataFrame action and parse the reply
MLWorkerWSBaseDTO result = mlWorkerWSCommService.performAction(
model.getProject().isUsingInternalWorker() ? MLWorkerID.INTERNAL : MLWorkerID.EXTERNAL,
workerID,
MLWorkerWSAction.RUN_MODEL_FOR_DATA_FRAME,
param
);
Expand Down
10 changes: 10 additions & 0 deletions backend/src/main/java/ai/giskard/utils/GalleryUnlockIndicator.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package ai.giskard.utils;

import lombok.Getter;
import lombok.Setter;

@Getter
@Setter
public class GalleryUnlockIndicator {
private boolean isUnlocked = true;
}
15 changes: 15 additions & 0 deletions backend/src/main/java/ai/giskard/web/dto/GalleryUnlockDTO.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package ai.giskard.web.dto;

import com.dataiku.j2ts.annotations.UIModel;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;

@Getter
@Setter
@NoArgsConstructor
@UIModel
public class GalleryUnlockDTO {
private String token;
private boolean unlocked;
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,7 @@ public static class AppInfoDTO {
private String hfSpaceId;
@JsonProperty(value = "isRunningOnHfSpaces")
private boolean isRunningOnHfSpaces;
@JsonProperty(value = "isDemoHfSpace")
private boolean isDemoHfSpace;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package ai.giskard.web.rest.controllers;

import ai.giskard.config.ApplicationProperties;
import ai.giskard.service.InitService;
import ai.giskard.web.dto.GalleryUnlockDTO;
import lombok.RequiredArgsConstructor;
import org.springframework.http.HttpStatus;
import org.springframework.util.StringUtils;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.server.ResponseStatusException;

import java.util.Objects;

@RestController
@RequiredArgsConstructor
@RequestMapping("/api/v2/hfs/unlock")
public class GalleryUnlockController {
private final ApplicationProperties applicationProperties;

@GetMapping("")
public GalleryUnlockDTO getUnlockStatus() {
GalleryUnlockDTO unlockDTO = new GalleryUnlockDTO();
unlockDTO.setUnlocked(InitService.isUnlocked());
unlockDTO.setToken(null);
return unlockDTO;
}

@PostMapping("")
public GalleryUnlockDTO setUnlockStatus(@RequestBody GalleryUnlockDTO unlockDTO) {
if (StringUtils.hasText(unlockDTO.getToken()) &&
Objects.equals(unlockDTO.getToken(), applicationProperties.getHfDemoSpaceUnlockToken())) {
InitService.setUnlocked(unlockDTO.isUnlocked());
unlockDTO.setUnlocked(InitService.isUnlocked());
unlockDTO.setToken(null);
return unlockDTO;
}
throw new ResponseStatusException(HttpStatus.FORBIDDEN, "Invalid unlock token");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,9 @@ public AppConfigDTO getApplicationSettings(@AuthenticationPrincipal final UserDe
.buildCommitTime(buildCommitTime)
.planCode(currentLicense.getPlanCode())
.planName(currentLicense.getPlanName())
.hfSpaceId(GeneralSettingsService.hfSpaceId)
.isRunningOnHfSpaces(GeneralSettingsService.isRunningInHFSpaces)
.hfSpaceId(GeneralSettingsService.HF_SPACE_ID)
.isRunningOnHfSpaces(GeneralSettingsService.IS_RUNNING_IN_HFSPACES)
.isDemoHfSpace(GeneralSettingsService.IS_RUNNING_IN_DEMO_HF_SPACES)
.roles(roles)
.build())
.user(userDTO)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ public TestTemplateExecutionResultDTO runAdHocTest(@RequestBody RunAdhocTestRequ

Project project = projectRepository.getMandatoryById(request.getProjectId());

boolean usingInternalWorker = project.isUsingInternalWorker();
MLWorkerID workerID = usingInternalWorker ? MLWorkerID.INTERNAL : MLWorkerID.EXTERNAL;
MLWorkerID workerID = project.isUsingInternalWorker() ? MLWorkerID.INTERNAL : MLWorkerID.EXTERNAL;
if (mlWorkerWSService.isWorkerConnected(workerID)) {
MLWorkerWSRunAdHocTestParamDTO.MLWorkerWSRunAdHocTestParamDTOBuilder builder =
MLWorkerWSRunAdHocTestParamDTO.builder()
Expand Down
Loading

0 comments on commit d5eaa73

Please sign in to comment.