Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FR] Make sure formType is the same as what we shipped in 3.0 for service v2.0 #20785

Merged
merged 1 commit into from
May 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ namespace Azure.AI.FormRecognizer.Training
/// </summary>
public class CreateComposedModelOperation : CreateCustomFormModelOperation
{
internal CreateComposedModelOperation(string location, FormRecognizerRestClient allOperations, ClientDiagnostics diagnostics) : base(location, allOperations, diagnostics) { }
internal CreateComposedModelOperation(string location, FormRecognizerRestClient allOperations, ClientDiagnostics diagnostics, FormRecognizerClientOptions.ServiceVersion serviceVersion)
: base(location, allOperations, diagnostics, serviceVersion) { }

/// <summary>
/// Initializes a new instance of the <see cref="CreateComposedModelOperation"/> class which
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ public class CreateCustomFormModelOperation : Operation<CustomFormModel>
/// <summary>Provides tools for exception creation in case of failure.</summary>
private readonly ClientDiagnostics _diagnostics;

/// <summary>Service version used in this client.</summary>
private readonly FormRecognizerClientOptions.ServiceVersion _serviceVersion;

private RequestFailedException _requestFailedException;

/// <summary>The last HTTP response received from the server. <c>null</c> until the first response is received.</summary>
Expand Down Expand Up @@ -105,10 +108,15 @@ public override ValueTask<Response<CustomFormModel>> WaitForCompletionAsync(Canc
public override ValueTask<Response<CustomFormModel>> WaitForCompletionAsync(TimeSpan pollingInterval, CancellationToken cancellationToken = default) =>
this.DefaultWaitForCompletionAsync(pollingInterval, cancellationToken);

internal CreateCustomFormModelOperation(string location, FormRecognizerRestClient allOperations, ClientDiagnostics diagnostics)
internal CreateCustomFormModelOperation(
string location,
FormRecognizerRestClient allOperations,
ClientDiagnostics diagnostics,
FormRecognizerClientOptions.ServiceVersion serviceVersion)
{
_serviceClient = allOperations;
_diagnostics = diagnostics;
_serviceVersion = serviceVersion;

// TODO: validate this
// https://github.com/Azure/azure-sdk-for-net/issues/10385
Expand All @@ -128,6 +136,7 @@ public CreateCustomFormModelOperation(string operationId, FormTrainingClient cli
Id = operationId;
_diagnostics = client.Diagnostics;
_serviceClient = client.ServiceClient;
_serviceVersion = client.ServiceVersion;
}

/// <summary>
Expand Down Expand Up @@ -185,7 +194,7 @@ private async ValueTask<Response> UpdateStatusAsync(bool async, CancellationToke
if (update.Value.ModelInfo.Status == CustomFormModelStatus.Ready)
{
// We need to first assign a value and then mark the operation as completed to avoid a race condition with the getter in Value
_value = new CustomFormModel(update.Value);
_value = new CustomFormModel(update.Value, _serviceVersion);
_hasCompleted = true;
}
else if (update.Value.ModelInfo.Status == CustomFormModelStatus.Invalid)
Expand Down
25 changes: 16 additions & 9 deletions sdk/formrecognizer/Azure.AI.FormRecognizer/src/CustomFormModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ namespace Azure.AI.FormRecognizer.Training
/// </summary>
public class CustomFormModel
{
internal CustomFormModel(Model model)
internal CustomFormModel(Model model, FormRecognizerClientOptions.ServiceVersion serviceVersion)
{
ModelId = model.ModelInfo.ModelId;
ModelName = model.ModelInfo.ModelName;
Status = model.ModelInfo.Status;
TrainingStartedOn = model.ModelInfo.TrainingStartedOn;
TrainingCompletedOn = model.ModelInfo.TrainingCompletedOn;
Submodels = ConvertToSubmodels(model);
Submodels = ConvertToSubmodels(model, serviceVersion);
TrainingDocuments = ConvertToTrainingDocuments(model);
Errors = model.TrainResult?.Errors ?? new List<FormRecognizerError>();
Properties = model.ModelInfo.Properties ?? new CustomFormModelProperties();
Expand Down Expand Up @@ -106,13 +106,13 @@ internal CustomFormModel(
/// </summary>
public IReadOnlyList<FormRecognizerError> Errors { get; }

private static IReadOnlyList<CustomFormSubmodel> ConvertToSubmodels(Model model)
private static IReadOnlyList<CustomFormSubmodel> ConvertToSubmodels(Model model, FormRecognizerClientOptions.ServiceVersion serviceVersion = default)
{
if (model.Keys != null)
return ConvertFromUnlabeled(model);

if (model.TrainResult != null)
return ConvertFromLabeled(model);
return ConvertFromLabeled(model, serviceVersion);

if (model.ComposedTrainResults != null)
return ConvertFromLabeledComposedModel(model);
Expand Down Expand Up @@ -142,13 +142,20 @@ private static IReadOnlyList<CustomFormSubmodel> ConvertFromUnlabeled(Model mode
return subModels;
}

private static IReadOnlyList<CustomFormSubmodel> ConvertFromLabeled(Model model)
private static IReadOnlyList<CustomFormSubmodel> ConvertFromLabeled(Model model, FormRecognizerClientOptions.ServiceVersion serviceVersion = default)
{
string formType = string.Empty;
if (string.IsNullOrEmpty(model.ModelInfo.ModelName))
formType = $"custom:{model.ModelInfo.ModelId}";
string formType;
if (serviceVersion == FormRecognizerClientOptions.ServiceVersion.V2_0)
{
formType = $"form-{model.ModelInfo.ModelId}";
}
else
formType = $"custom:{model.ModelInfo.ModelName}";
{
if (string.IsNullOrEmpty(model.ModelInfo.ModelName))
formType = $"custom:{model.ModelInfo.ModelId}";
else
formType = $"custom:{model.ModelInfo.ModelName}";
}

return new List<CustomFormSubmodel> {
new CustomFormSubmodel(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ public class FormTrainingClient
/// <summary>Provides tools for exception creation in case of failure.</summary>
internal readonly ClientDiagnostics Diagnostics;

/// <summary>Service version used in this client.</summary>
internal readonly FormRecognizerClientOptions.ServiceVersion ServiceVersion = FormRecognizerClientOptions.LatestVersion;

/// <summary>
/// Initializes a new instance of the <see cref="FormTrainingClient"/> class.
/// </summary>
Expand Down Expand Up @@ -65,8 +68,9 @@ public FormTrainingClient(Uri endpoint, AzureKeyCredential credential, FormRecog
Argument.AssertNotNull(options, nameof(options));

Diagnostics = new ClientDiagnostics(options);
ServiceVersion = options.Version;
HttpPipeline pipeline = HttpPipelineBuilder.Build(options, new AzureKeyCredentialPolicy(credential, Constants.AuthorizationHeader));
ServiceClient = new FormRecognizerRestClient(Diagnostics, pipeline, endpoint.AbsoluteUri, FormRecognizerClientOptions.GetVersionString(options.Version));
ServiceClient = new FormRecognizerRestClient(Diagnostics, pipeline, endpoint.AbsoluteUri, FormRecognizerClientOptions.GetVersionString(ServiceVersion));
}

/// <summary>
Expand Down Expand Up @@ -100,8 +104,9 @@ public FormTrainingClient(Uri endpoint, TokenCredential credential, FormRecogniz
Argument.AssertNotNull(options, nameof(options));

Diagnostics = new ClientDiagnostics(options);
ServiceVersion = options.Version;
var pipeline = HttpPipelineBuilder.Build(options, new BearerTokenAuthenticationPolicy(credential, Constants.DefaultCognitiveScope));
ServiceClient = new FormRecognizerRestClient(Diagnostics, pipeline, endpoint.AbsoluteUri, FormRecognizerClientOptions.GetVersionString(options.Version));
ServiceClient = new FormRecognizerRestClient(Diagnostics, pipeline, endpoint.AbsoluteUri, FormRecognizerClientOptions.GetVersionString(ServiceVersion));
}

#region Training
Expand Down Expand Up @@ -139,7 +144,7 @@ public virtual TrainingOperation StartTraining(Uri trainingFilesUri, bool useTra
};

ResponseWithHeaders<FormRecognizerTrainCustomModelAsyncHeaders> response = ServiceClient.TrainCustomModelAsync(trainRequest, cancellationToken);
return new TrainingOperation(response.Headers.Location, ServiceClient, Diagnostics);
return new TrainingOperation(response.Headers.Location, ServiceClient, Diagnostics, ServiceVersion);
}
catch (Exception e)
{
Expand Down Expand Up @@ -182,7 +187,7 @@ public virtual async Task<TrainingOperation> StartTrainingAsync(Uri trainingFile
};

ResponseWithHeaders<FormRecognizerTrainCustomModelAsyncHeaders> response = await ServiceClient.TrainCustomModelAsyncAsync(trainRequest, cancellationToken).ConfigureAwait(false);
return new TrainingOperation(response.Headers.Location, ServiceClient, Diagnostics);
return new TrainingOperation(response.Headers.Location, ServiceClient, Diagnostics, ServiceVersion);
}
catch (Exception e)
{
Expand Down Expand Up @@ -223,7 +228,7 @@ public virtual TrainingOperation StartTraining(Uri trainingFilesUri, bool useTra
};

ResponseWithHeaders<FormRecognizerTrainCustomModelAsyncHeaders> response = ServiceClient.TrainCustomModelAsync(trainRequest, cancellationToken);
return new TrainingOperation(response.Headers.Location, ServiceClient, Diagnostics);
return new TrainingOperation(response.Headers.Location, ServiceClient, Diagnostics, ServiceVersion);
}
catch (Exception e)
{
Expand Down Expand Up @@ -264,7 +269,7 @@ public virtual async Task<TrainingOperation> StartTrainingAsync(Uri trainingFile
};

ResponseWithHeaders<FormRecognizerTrainCustomModelAsyncHeaders> response = await ServiceClient.TrainCustomModelAsyncAsync(trainRequest, cancellationToken).ConfigureAwait(false);
return new TrainingOperation(response.Headers.Location, ServiceClient, Diagnostics);
return new TrainingOperation(response.Headers.Location, ServiceClient, Diagnostics, ServiceVersion);
}
catch (Exception e)
{
Expand Down Expand Up @@ -306,7 +311,7 @@ public virtual CreateComposedModelOperation StartCreateComposedModel(IEnumerable
composeRequest.ModelName = modelName;

ResponseWithHeaders<FormRecognizerComposeCustomModelsAsyncHeaders> response = ServiceClient.ComposeCustomModelsAsync(composeRequest, cancellationToken);
return new CreateComposedModelOperation(response.Headers.Location, ServiceClient, Diagnostics);
return new CreateComposedModelOperation(response.Headers.Location, ServiceClient, Diagnostics, ServiceVersion);
}
catch (Exception e)
{
Expand Down Expand Up @@ -344,7 +349,7 @@ public virtual async Task<CreateComposedModelOperation> StartCreateComposedModel
composeRequest.ModelName = modelName;

ResponseWithHeaders<FormRecognizerComposeCustomModelsAsyncHeaders> response = await ServiceClient.ComposeCustomModelsAsyncAsync(composeRequest, cancellationToken).ConfigureAwait(false);
return new CreateComposedModelOperation(response.Headers.Location, ServiceClient, Diagnostics);
return new CreateComposedModelOperation(response.Headers.Location, ServiceClient, Diagnostics, ServiceVersion);
}
catch (Exception e)
{
Expand Down Expand Up @@ -376,7 +381,7 @@ public virtual Response<CustomFormModel> GetCustomModel(string modelId, Cancella
Guid guid = ClientCommon.ValidateModelId(modelId, nameof(modelId));

Response<Model> response = ServiceClient.GetCustomModel(guid, includeKeys: true, cancellationToken);
return Response.FromValue(new CustomFormModel(response.Value), response.GetRawResponse());
return Response.FromValue(new CustomFormModel(response.Value, ServiceVersion), response.GetRawResponse());
}
catch (Exception e)
{
Expand Down Expand Up @@ -404,7 +409,7 @@ public virtual async Task<Response<CustomFormModel>> GetCustomModelAsync(string
Guid guid = ClientCommon.ValidateModelId(modelId, nameof(modelId));

Response<Model> response = await ServiceClient.GetCustomModelAsync(guid, includeKeys: true, cancellationToken).ConfigureAwait(false);
return Response.FromValue(new CustomFormModel(response.Value), response.GetRawResponse());
return Response.FromValue(new CustomFormModel(response.Value, ServiceVersion), response.GetRawResponse());
}
catch (Exception e)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ namespace Azure.AI.FormRecognizer.Training
/// </summary>
public class TrainingOperation : CreateCustomFormModelOperation
{
internal TrainingOperation(string location, FormRecognizerRestClient allOperations, ClientDiagnostics diagnostics) : base(location, allOperations, diagnostics) { }
internal TrainingOperation(string location, FormRecognizerRestClient allOperations, ClientDiagnostics diagnostics, FormRecognizerClientOptions.ServiceVersion serviceVersion)
: base(location, allOperations, diagnostics, serviceVersion) { }

/// <summary>
/// Initializes a new instance of the <see cref="TrainingOperation"/> class which
Expand Down