diff --git a/common/core/src/main/java/zingg/common/core/executor/Labeller.java b/common/core/src/main/java/zingg/common/core/executor/Labeller.java index a9ce811a4..7d054c966 100644 --- a/common/core/src/main/java/zingg/common/core/executor/Labeller.java +++ b/common/core/src/main/java/zingg/common/core/executor/Labeller.java @@ -2,6 +2,8 @@ import java.util.List; import java.util.Scanner; +import java.util.Objects; +import java.util.Set; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -19,22 +21,45 @@ public abstract class Labeller extends ZinggBase implements IPreprocessors { - public static final Integer QUIT_LABELING = 9; - public static final Integer INCREMENT = 1; + public enum LabelAction { + MATCH(1), + NO_MATCH(0), + NOT_SURE(2), + QUIT(9); + + private final int code; + + LabelAction(int code) { + this.code = code; + } + public int getCode() { + return code; + } + } + private static final int STAT_INCREMENT = 1; private static final long serialVersionUID = 1L; - protected static String name = "zingg.common.core.executor.Labeller"; + // protected static final String NAME = "zingg.common.core.executor.Labeller"; + private static final String LEFT_ANTI_JOIN = "left_anti"; public static final Log LOG = LogFactory.getLog(Labeller.class); - protected ITrainingDataModel trainingDataModel; - protected ILabelDataViewHelper labelDataViewHelper; - - public Labeller() { - setZinggOption(ZinggOptions.LABEL); + private final Scanner scanner = new Scanner(System.in); + private final ITrainingDataModel trainingDataModel; + private final ILabelDataViewHelper labelDataViewHelper; + public Labeller(ITrainingDataModel trainingDataModel, + ILabelDataViewHelper labelDataViewHelper) { + this.trainingDataModel = trainingDataModel; + this.labelDataViewHelper = labelDataViewHelper; } public void execute() throws ZinggClientException { + setZinggOption(ZinggOptions.LABEL); + try { - LabellerUtil labellerUtil = new LabellerUtil(); + LabellerUtil labellerUtil = new LabellerUtil<>(); LOG.info("Reading inputs for labelling phase ..."); + if(getMarkedRecords() == null) { + LOG.info("No marked records found. Initializing the marked records stat."); + return; + } getTrainingDataModel().setMarkedRecordsStat(getMarkedRecords()); ZFrame unmarkedRecords = getUnmarkedRecords(); ZFrame preprocessedUnmarkedRecords = preprocess(unmarkedRecords); @@ -45,12 +70,21 @@ public void execute() throws ZinggClientException { getTrainingDataModel().writeLabelledOutput(postProcessedLabelledRecords,args); } LOG.info("Finished labelling phase"); - } catch (Exception e) { - throw new ZinggClientException("Error in labelling phase ", e); + }catch(ZinggClientException e) { + LOG.error("Error while labelling records", e); + throw e; + + }catch (RuntimeException e) { + LOG.error("Unexpected error has occurred while labelling records", e); + throw new ZinggClientException("Unexpected error occurred while labelling records", e); + }finally { + scanner.close(); } + + } - + public ZFrame getUnmarkedRecords() { ZFrame unmarkedRecords = null; ZFrame markedRecords = null; @@ -58,137 +92,118 @@ public ZFrame getUnmarkedRecords() { unmarkedRecords = getPipeUtil().read(false, false, getModelHelper().getTrainingDataUnmarkedPipe(args)); try { markedRecords = getPipeUtil().read(false, false, getModelHelper().getTrainingDataMarkedPipe(args)); - } catch (Exception e) { - LOG.warn("No record has been marked yet"); - } catch (ZinggClientException zce) { - LOG.warn("No record has been marked yet"); + }catch (ZinggClientException zce) { + LOG.warn("No record has been marked yet", zce); } if (markedRecords != null ) { - unmarkedRecords = unmarkedRecords.join(markedRecords,ColName.CLUSTER_COLUMN, false, - "left_anti"); + unmarkedRecords = unmarkedRecords.join(markedRecords,ColName.CLUSTER_COLUMN, false,LEFT_ANTI_JOIN); getTrainingDataModel().setMarkedRecordsStat(markedRecords); } - } catch (Exception e) { - LOG.warn("No unmarked record for labelling"); - } catch (ZinggClientException e1) { - // TODO Auto-generated catch block - e1.printStackTrace(); + } catch (ZinggClientException e) { + // No marked records available, continuing with unmarked records only + LOG.error("Error while reading unmarked records for labelling", e); } return unmarkedRecords; } - public ZFrame processRecordsCli(ZFrame lines) throws ZinggClientException { + + public ZFrame processRecordsCli(ZFrame unmarkedRecords) throws ZinggClientException { LOG.info("Processing Records for CLI Labelling"); - if (lines != null && lines.count() > 0) { - getLabelDataViewHelper().printMarkedRecordsStat( - getTrainingDataModel().getPositivePairsCount(), - getTrainingDataModel().getNegativePairsCount(), - getTrainingDataModel().getNotSurePairsCount(), - getTrainingDataModel().getTotalCount() - ); - - lines = lines.cache(); -// List displayCols = getLabelDataViewHelper().getDisplayColumns(lines, args); - ZidAndFieldDefSelector zidAndFieldDefSelector = new ZidAndFieldDefSelector(args.getFieldDefinition(), false, args.getShowConcise()); - //have to introduce as snowframe can not handle row.getAs with column - //name and row and lines are out of order for the code to work properly - //snow getAsString expects row to have same struc as dataframe which is - //not happening - ZFrame clusterIdZFrame = getLabelDataViewHelper().getClusterIdsFrame(lines); - List clusterIDs = getLabelDataViewHelper().getClusterIds(clusterIdZFrame); + if (unmarkedRecords != null && !unmarkedRecords.isEmpty()) { + printStatistics(); + + unmarkedRecords = unmarkedRecords.cache(); + ZidAndFieldDefSelector zidAndFieldDefSelector = new ZidAndFieldDefSelector(args.getFieldDefinition(), false, args.getShowConcise()); + //have to introduce as snowframe can not handle row.getAs with column + //name and row and lines are out of order for the code to work properly + //snow getAsString expects row to have same struc as dataframe which is + //not happening + ZFrame clusterIdZFrame = getLabelDataViewHelper().getClusterIdsFrame(unmarkedRecords); + List clusterIDs = getLabelDataViewHelper().getClusterIds(clusterIdZFrame); + Objects.requireNonNull(clusterIDs, "Cluster IDs cannot be null"); try { double score; double prediction; ZFrame updatedRecords = null; int selectedOption = -1; - String msg1, msg2; + String progressMessage, predictionMessage; int totalPairs = clusterIDs.size(); for (int index = 0; index < totalPairs; index++) { - ZFrame currentPair = getLabelDataViewHelper().getCurrentPair(lines, index, clusterIDs, clusterIdZFrame); + ZFrame currentPair = getLabelDataViewHelper().getCurrentPair(unmarkedRecords, index, clusterIDs, clusterIdZFrame); + Objects.requireNonNull(currentPair, "Current pair to label cannot be null"); score = getLabelDataViewHelper().getScore(currentPair); prediction = getLabelDataViewHelper().getPrediction(currentPair); - msg1 = getLabelDataViewHelper().getMsg1(index, totalPairs); - msg2 = getLabelDataViewHelper().getMsg2(prediction, score); - //String msgHeader = msg1 + msg2; - -// selectedOption = displayRecordsAndGetUserInput(getDSUtil().select(currentPair, displayCols), msg1, msg2); - selectedOption = displayRecordsAndGetUserInput(currentPair.select(zidAndFieldDefSelector.getCols()), msg1, msg2); - getTrainingDataModel().updateLabellerStat(selectedOption, INCREMENT); - getLabelDataViewHelper().printMarkedRecordsStat( - getTrainingDataModel().getPositivePairsCount(), - getTrainingDataModel().getNegativePairsCount(), - getTrainingDataModel().getNotSurePairsCount(), - getTrainingDataModel().getTotalCount() - ); - if (selectedOption == QUIT_LABELING) { + progressMessage = getLabelDataViewHelper().getMsg1(index, totalPairs); + predictionMessage = getLabelDataViewHelper().getMsg2(prediction, score); + selectedOption = displayRecordsAndGetUserInput(currentPair.select(zidAndFieldDefSelector.getCols()), progressMessage, predictionMessage); + getTrainingDataModel().updateLabellerStat(selectedOption, STAT_INCREMENT); + printStatistics(); + if (selectedOption == LabelAction.QUIT.getCode()) { LOG.info("User has quit in the middle. Updating the records."); break; } updatedRecords = getTrainingDataModel().updateRecords(selectedOption, currentPair, updatedRecords); } - LOG.warn("Processing finished."); + LOG.info("Processing finished."); return updatedRecords; - } catch (Exception e) { - LOG.warn("Labelling error has occured " + e.getMessage()); - throw new ZinggClientException("An error has occured while Labelling.", e); + } catch (ZinggClientException e) { + LOG.error("Error while processing records for labelling", e); + throw e; + } + catch (RuntimeException e) { + LOG.error("Unexpected error has occurred while processing records for labelling", e); + throw new ZinggClientException("Unexpected error occurred while processing records for labelling", e); } } else { LOG.info("It seems there are no unmarked records at this moment. Please run findTrainingData job to build some pairs to be labelled and then run this labeler."); return null; } } - protected int displayRecordsAndGetUserInput(ZFrame records, String preMessage, String postMessage) throws ZinggClientException { getLabelDataViewHelper().displayRecords(records, preMessage, postMessage); int selection = readCliInput(); return selection; } - - - int readCliInput() { - Scanner sc = new Scanner(System.in); - - while (!sc.hasNext("[0129]")) { - sc.next(); - System.out.println("Nope, please enter one of the allowed options!"); + private boolean isValidOption(String input){ + try { + int code = Integer.parseInt(input); + return java.util.Arrays.stream(LabelAction.values()) + .anyMatch(action -> action.getCode() == code); + } catch (NumberFormatException e) { + return false; } - String word = sc.next(); - int selection = Integer.parseInt(word); - // sc.close(); + } + private int readCliInput ()throws ZinggClientException{ - return selection; + while (true) { + if(!scanner.hasNext()) { + throw new ZinggClientException("No input received from user"); + } + String userInput = scanner.next().trim(); + if (isValidOption(userInput)) { + return Integer.parseInt(userInput); + } + System.out.println("Invalid input. Allowed values: 0, 1, 2, 9"); + } } @Override public ITrainingDataModel getTrainingDataModel() { - if (trainingDataModel==null) { - this.trainingDataModel = new TrainingDataModel(getContext(), getClientOptions()); - } return trainingDataModel; } - - public void setTrainingDataModel(ITrainingDataModel trainingDataModel) { - this.trainingDataModel = trainingDataModel; - } - - public ILabelDataViewHelper getLabelDataViewHelper() { - if(labelDataViewHelper==null) { - labelDataViewHelper = new LabelDataViewHelper(getContext(), getClientOptions()); - labelDataViewHelper.initVerticalDisplayUtility(getDfObjectUtil()); - } return labelDataViewHelper; } - - public void setLabelDataViewHelper(ILabelDataViewHelper labelDataViewHelper) { - this.labelDataViewHelper = labelDataViewHelper; + private void printStatistics() { + getLabelDataViewHelper().printMarkedRecordsStat( + getTrainingDataModel().getPositivePairsCount(), + getTrainingDataModel().getNegativePairsCount(), + getTrainingDataModel().getNotSurePairsCount(), + getTrainingDataModel().getTotalCount()); } - protected abstract DFObjectUtil getDfObjectUtil(); } - -