Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 108 additions & 93 deletions common/core/src/main/java/zingg/common/core/executor/Labeller.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import java.util.List;
import java.util.Scanner;
import java.util.Objects;
import java.util.Set;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
Expand All @@ -19,22 +21,45 @@

public abstract class Labeller<S,D,R,C,T> extends ZinggBase<S,D,R,C,T> implements IPreprocessors<S,D,R,C,T> {

public static final Integer QUIT_LABELING = 9;
public static final Integer INCREMENT = 1;
/**
 * Labelling choices, each paired with the integer code that represents it.
 */
public enum LabelAction {
    MATCH(1),
    NO_MATCH(0),
    NOT_SURE(2),
    QUIT(9);

    /** Integer code associated with this action. */
    private final int numericCode;

    LabelAction(int numericCode) {
        this.numericCode = numericCode;
    }

    /**
     * @return the integer code associated with this action
     */
    public int getCode() {
        return this.numericCode;
    }
}
private static final int STAT_INCREMENT = 1;
private static final long serialVersionUID = 1L;
protected static String name = "zingg.common.core.executor.Labeller";
// protected static final String NAME = "zingg.common.core.executor.Labeller";
private static final String LEFT_ANTI_JOIN = "left_anti";
public static final Log LOG = LogFactory.getLog(Labeller.class);
protected ITrainingDataModel<S, D, R, C> trainingDataModel;
protected ILabelDataViewHelper<S, D, R, C> labelDataViewHelper;

public Labeller() {
setZinggOption(ZinggOptions.LABEL);
private final Scanner scanner = new Scanner(System.in);
private final ITrainingDataModel<S, D, R, C> trainingDataModel;
private final ILabelDataViewHelper<S, D, R, C> labelDataViewHelper;
public Labeller(ITrainingDataModel<S,D,R,C> trainingDataModel,
ILabelDataViewHelper<S,D,R,C> labelDataViewHelper) {
this.trainingDataModel = trainingDataModel;
this.labelDataViewHelper = labelDataViewHelper;
}

public void execute() throws ZinggClientException {
setZinggOption(ZinggOptions.LABEL);

try {
LabellerUtil<D, R, C> labellerUtil = new LabellerUtil<D, R, C>();
LabellerUtil<D, R, C> labellerUtil = new LabellerUtil<>();
LOG.info("Reading inputs for labelling phase ...");
if(getMarkedRecords() == null) {
LOG.info("No marked records found. Initializing the marked records stat.");
return;
}
getTrainingDataModel().setMarkedRecordsStat(getMarkedRecords());
ZFrame<D,R,C> unmarkedRecords = getUnmarkedRecords();
ZFrame<D, R, C> preprocessedUnmarkedRecords = preprocess(unmarkedRecords);
Expand All @@ -45,150 +70,140 @@ public void execute() throws ZinggClientException {
getTrainingDataModel().writeLabelledOutput(postProcessedLabelledRecords,args);
}
LOG.info("Finished labelling phase");
} catch (Exception e) {
throw new ZinggClientException("Error in labelling phase ", e);
}catch(ZinggClientException e) {
LOG.error("Error while labelling records", e);
throw e;

}catch (RuntimeException e) {
LOG.error("Unexpected error has occurred while labelling records", e);
throw new ZinggClientException("Unexpected error occurred while labelling records", e);
}finally {
scanner.close();
}


}


/**
 * Reads the unmarked training records and, when marked records can also be
 * read, left-anti joins the unmarked records against them on
 * {@code ColName.CLUSTER_COLUMN} and passes the marked records to
 * {@code setMarkedRecordsStat}.
 *
 * NOTE(review): this span shows removed and added diff lines together without
 * +/- markers (repeated catch clauses, two forms of the same join). Confirm
 * which lines belong to the current revision before relying on it.
 *
 * @return the unmarked records; {@code null} if the initial read throws and
 *         the exception is caught here
 */
public ZFrame<D,R,C> getUnmarkedRecords() {
ZFrame<D,R,C> unmarkedRecords = null;
ZFrame<D,R,C> markedRecords = null;
try {
unmarkedRecords = getPipeUtil().read(false, false, getModelHelper().getTrainingDataUnmarkedPipe(args));
// Inner try: failing to read marked records is tolerated; markedRecords stays null.
try {
markedRecords = getPipeUtil().read(false, false, getModelHelper().getTrainingDataMarkedPipe(args));
} catch (Exception e) {
LOG.warn("No record has been marked yet");
} catch (ZinggClientException zce) {
LOG.warn("No record has been marked yet");
}catch (ZinggClientException zce) {
LOG.warn("No record has been marked yet", zce);
}
if (markedRecords != null ) {
unmarkedRecords = unmarkedRecords.join(markedRecords,ColName.CLUSTER_COLUMN, false,
"left_anti");
unmarkedRecords = unmarkedRecords.join(markedRecords,ColName.CLUSTER_COLUMN, false,LEFT_ANTI_JOIN);
getTrainingDataModel().setMarkedRecordsStat(markedRecords);
}
} catch (Exception e) {
LOG.warn("No unmarked record for labelling");
} catch (ZinggClientException e1) {
// TODO Auto-generated catch block
e1.printStackTrace();
} catch (ZinggClientException e) {
// Outer try failed (reading unmarked records, or the join): log it and return whatever was read so far
LOG.error("Error while reading unmarked records for labelling", e);
}
return unmarkedRecords;
}

public ZFrame<D,R,C> processRecordsCli(ZFrame<D,R,C> lines) throws ZinggClientException {

public ZFrame<D,R,C> processRecordsCli(ZFrame<D,R,C> unmarkedRecords) throws ZinggClientException {
LOG.info("Processing Records for CLI Labelling");
if (lines != null && lines.count() > 0) {
getLabelDataViewHelper().printMarkedRecordsStat(
getTrainingDataModel().getPositivePairsCount(),
getTrainingDataModel().getNegativePairsCount(),
getTrainingDataModel().getNotSurePairsCount(),
getTrainingDataModel().getTotalCount()
);

lines = lines.cache();
// List<C> displayCols = getLabelDataViewHelper().getDisplayColumns(lines, args);
ZidAndFieldDefSelector zidAndFieldDefSelector = new ZidAndFieldDefSelector(args.getFieldDefinition(), false, args.getShowConcise());
//have to introduce as snowframe can not handle row.getAs with column
//name and row and lines are out of order for the code to work properly
//snow getAsString expects row to have same struc as dataframe which is
//not happening
ZFrame<D,R,C> clusterIdZFrame = getLabelDataViewHelper().getClusterIdsFrame(lines);
List<R> clusterIDs = getLabelDataViewHelper().getClusterIds(clusterIdZFrame);
if (unmarkedRecords != null && !unmarkedRecords.isEmpty()) {
printStatistics();

unmarkedRecords = unmarkedRecords.cache();
ZidAndFieldDefSelector zidAndFieldDefSelector = new ZidAndFieldDefSelector(args.getFieldDefinition(), false, args.getShowConcise());
// This extra cluster-id frame has to be introduced because snowframe cannot
// handle row.getAs with a column name: the row and the lines are out of order,
// and snow getAsString expects the row to have the same structure as the
// dataframe, which does not hold here.
ZFrame<D,R,C> clusterIdZFrame = getLabelDataViewHelper().getClusterIdsFrame(unmarkedRecords);
List<R> clusterIDs = getLabelDataViewHelper().getClusterIds(clusterIdZFrame);
Objects.requireNonNull(clusterIDs, "Cluster IDs cannot be null");
try {
double score;
double prediction;
ZFrame<D,R,C> updatedRecords = null;
int selectedOption = -1;
String msg1, msg2;
String progressMessage, predictionMessage;
int totalPairs = clusterIDs.size();

for (int index = 0; index < totalPairs; index++) {
ZFrame<D,R,C> currentPair = getLabelDataViewHelper().getCurrentPair(lines, index, clusterIDs, clusterIdZFrame);
ZFrame<D,R,C> currentPair = getLabelDataViewHelper().getCurrentPair(unmarkedRecords, index, clusterIDs, clusterIdZFrame);
Objects.requireNonNull(currentPair, "Current pair to label cannot be null");

score = getLabelDataViewHelper().getScore(currentPair);
prediction = getLabelDataViewHelper().getPrediction(currentPair);

msg1 = getLabelDataViewHelper().getMsg1(index, totalPairs);
msg2 = getLabelDataViewHelper().getMsg2(prediction, score);
//String msgHeader = msg1 + msg2;

// selectedOption = displayRecordsAndGetUserInput(getDSUtil().select(currentPair, displayCols), msg1, msg2);
selectedOption = displayRecordsAndGetUserInput(currentPair.select(zidAndFieldDefSelector.getCols()), msg1, msg2);
getTrainingDataModel().updateLabellerStat(selectedOption, INCREMENT);
getLabelDataViewHelper().printMarkedRecordsStat(
getTrainingDataModel().getPositivePairsCount(),
getTrainingDataModel().getNegativePairsCount(),
getTrainingDataModel().getNotSurePairsCount(),
getTrainingDataModel().getTotalCount()
);
if (selectedOption == QUIT_LABELING) {
progressMessage = getLabelDataViewHelper().getMsg1(index, totalPairs);
predictionMessage = getLabelDataViewHelper().getMsg2(prediction, score);
selectedOption = displayRecordsAndGetUserInput(currentPair.select(zidAndFieldDefSelector.getCols()), progressMessage, predictionMessage);
getTrainingDataModel().updateLabellerStat(selectedOption, STAT_INCREMENT);
printStatistics();
if (selectedOption == LabelAction.QUIT.getCode()) {
LOG.info("User has quit in the middle. Updating the records.");
break;
}
updatedRecords = getTrainingDataModel().updateRecords(selectedOption, currentPair, updatedRecords);
}
LOG.warn("Processing finished.");
LOG.info("Processing finished.");
return updatedRecords;
} catch (Exception e) {
LOG.warn("Labelling error has occured " + e.getMessage());
throw new ZinggClientException("An error has occured while Labelling.", e);
} catch (ZinggClientException e) {
LOG.error("Error while processing records for labelling", e);
throw e;
}
catch (RuntimeException e) {
LOG.error("Unexpected error has occurred while processing records for labelling", e);
throw new ZinggClientException("Unexpected error occurred while processing records for labelling", e);
}
} else {
LOG.info("It seems there are no unmarked records at this moment. Please run findTrainingData job to build some pairs to be labelled and then run this labeler.");
return null;
}
}


/**
 * Displays the given records through the label data view helper and then
 * reads the user's choice from the CLI.
 *
 * @param records     records to display
 * @param preMessage  message handed to {@code displayRecords} together with the records
 * @param postMessage second message handed to {@code displayRecords}
 * @return the value returned by {@code readCliInput()}
 * @throws ZinggClientException declared by this method; propagated from the calls it makes
 */
protected int displayRecordsAndGetUserInput(ZFrame<D,R,C> records, String preMessage, String postMessage) throws ZinggClientException {
    getLabelDataViewHelper().displayRecords(records, preMessage, postMessage);
    return readCliInput();
}


int readCliInput() {
Scanner sc = new Scanner(System.in);

while (!sc.hasNext("[0129]")) {
sc.next();
System.out.println("Nope, please enter one of the allowed options!");
private boolean isValidOption(String input){
try {
int code = Integer.parseInt(input);
return java.util.Arrays.stream(LabelAction.values())
.anyMatch(action -> action.getCode() == code);
} catch (NumberFormatException e) {
return false;
}
String word = sc.next();
int selection = Integer.parseInt(word);
// sc.close();
}
private int readCliInput ()throws ZinggClientException{

return selection;
while (true) {
if(!scanner.hasNext()) {
throw new ZinggClientException("No input received from user");
}
String userInput = scanner.next().trim();
if (isValidOption(userInput)) {
return Integer.parseInt(userInput);
}
System.out.println("Invalid input. Allowed values: 0, 1, 2, 9");
}
}

/**
 * Returns the training data model, creating a {@code TrainingDataModel} from
 * the current context and client options when none is set yet.
 *
 * @return the existing or newly created training data model
 */
@Override
public ITrainingDataModel<S, D, R, C> getTrainingDataModel() {
    if (trainingDataModel != null) {
        return trainingDataModel;
    }
    // Nothing set yet: build the default model and keep it for later calls.
    this.trainingDataModel = new TrainingDataModel<S, D, R, C, T>(getContext(), getClientOptions());
    return trainingDataModel;
}

/**
 * Replaces the training data model held by this labeller.
 *
 * @param trainingDataModel the model to store
 */
public void setTrainingDataModel(ITrainingDataModel<S, D, R, C> trainingDataModel) {
this.trainingDataModel = trainingDataModel;
}


/**
 * Returns the label data view helper. When none is set yet, a
 * {@code LabelDataViewHelper} is created from the current context and client
 * options and its vertical display utility is initialised with
 * {@code getDfObjectUtil()}.
 *
 * @return the existing or newly created label data view helper
 */
public ILabelDataViewHelper<S, D, R, C> getLabelDataViewHelper() {
    if (labelDataViewHelper != null) {
        return labelDataViewHelper;
    }
    labelDataViewHelper = new LabelDataViewHelper<S,D,R,C,T>(getContext(), getClientOptions());
    labelDataViewHelper.initVerticalDisplayUtility(getDfObjectUtil());
    return labelDataViewHelper;
}

public void setLabelDataViewHelper(ILabelDataViewHelper<S, D, R, C> labelDataViewHelper) {
this.labelDataViewHelper = labelDataViewHelper;
private void printStatistics() {
getLabelDataViewHelper().printMarkedRecordsStat(
getTrainingDataModel().getPositivePairsCount(),
getTrainingDataModel().getNegativePairsCount(),
getTrainingDataModel().getNotSurePairsCount(),
getTrainingDataModel().getTotalCount());
}

protected abstract DFObjectUtil<S, D, R, C> getDfObjectUtil();
}