SSAGES  0.1
A MetaDynamics Package
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Groups Pages
Umbrella.cpp
1 
21 #include "Umbrella.h"
22 #include "Snapshot.h"
23 #include "CVs/CVManager.h"
24 #include "Validator/ObjectRequirement.h"
25 #include "Drivers/DriverException.h"
26 #include "schema.h"
27 #include <iostream>
28 
29 using namespace Json;
30 
31 namespace SSAGES
32 {
33  void Umbrella::PreSimulation(Snapshot* /* snapshot */, const CVManager& cvmanager)
34  {
35  if(comm_.rank() == 0)
36  {
37  if(append_)
38  umbrella_.open(filename_.c_str(), std::ofstream::out | std::ofstream::app);
39  else
40  {
41  // Write out header.
42  umbrella_.open(filename_.c_str(), std::ofstream::out);
43  umbrella_ << "#";
44  umbrella_ << "Iteration ";
45 
46  auto cvs = cvmanager.GetCVs(cvmask_);
47  for(size_t i = 0; i < cvs.size(); ++i)
48  umbrella_ << "cv_" + std::to_string(i) << " ";
49 
50  for(size_t i = 0; i < cvs.size() - 1; ++i)
51  umbrella_ << "center_" + std::to_string(i) << " ";
52  umbrella_ << "center_" + std::to_string(cvs.size() - 1) << std::endl;
53  }
54  }
55  }
56 
57  void Umbrella::PostIntegration(Snapshot* snapshot, const CVManager& cvmanager)
58  {
59  // Get necessary info.
60  auto cvs = cvmanager.GetCVs(cvmask_);
61  auto& forces = snapshot->GetForces();
62  auto& virial = snapshot->GetVirial();
63  auto& H = snapshot->GetHMatrix();
64 
65  for(size_t i = 0; i < cvs.size(); ++i)
66  {
67  // Get current CV and gradient.
68  auto& cv = cvs[i];
69  auto& grad = cv->GetGradient();
70  auto& boxgrad = cv->GetBoxGradient();
71  // Compute dV/dCV.
72  auto center = GetCurrentCenter(snapshot->GetIteration(), i);
73  auto D = kspring_[i]*cv->GetDifference(center);
74 
75  // Update forces.
76  for(size_t j = 0; j < forces.size(); ++j)
77  forces[j] -= D*grad[j];
78 
79  // Update virial.
80  virial += D*boxgrad;
81  }
82 
83  if(snapshot->GetIteration() % outfreq_ == 0)
84  PrintUmbrella(cvs, snapshot->GetIteration());
85  }
86 
87  void Umbrella::PostSimulation(Snapshot*, const CVManager&)
88  {
89  if(comm_.rank() ==0)
90  umbrella_.close();
91  }
92 
93  void Umbrella::PrintUmbrella(const CVList& cvs, uint iteration)
94  {
95  if(comm_.rank() ==0)
96  {
97  umbrella_.precision(8);
98  umbrella_ << iteration << " ";
99 
100  // Print out CV values first.
101  for(auto& cv : cvs)
102  umbrella_ << cv->GetValue() << " ";
103 
104  // Print out target (center) of each CV.
105  for(size_t i = 0; i < cvs.size() - 1; ++i)
106  umbrella_ << GetCurrentCenter(iteration, i) << " ";
107  umbrella_ << GetCurrentCenter(iteration, cvs.size() - 1);
108 
109  umbrella_ << std::endl;
110  }
111  }
112 
113  Umbrella* Umbrella::Build(const Json::Value& json,
114  const MPI_Comm& world,
115  const MPI_Comm& comm,
116  const std::string& path)
117  {
118  ObjectRequirement validator;
119  Value schema;
120  Reader reader;
121 
122  reader.parse(JsonSchema::UmbrellaMethod, schema);
123  validator.Parse(schema, path);
124 
125  // Validate inputs.
126  validator.Validate(json, path);
127  if(validator.HasErrors())
128  throw BuildException(validator.GetErrors());
129 
130  //TODO walker id should be obtainable in method as
131  // opposed to calculated like this.
132  uint wid = mxx::comm(world).rank()/mxx::comm(comm).size();
133  bool ismulti = mxx::comm(world).size() > mxx::comm(comm).size();
134  uint wcount = mxx::comm(world).size() / mxx::comm(comm).size();
135 
136  std::vector<std::vector<double>> ksprings;
137  for(auto& s : json["ksprings"])
138  {
139  std::vector<double> kspring;
140  if(s.isArray())
141  for(auto& k : s)
142  kspring.push_back(k.asDouble());
143  else
144  kspring.push_back(s.asDouble());
145 
146  ksprings.push_back(kspring);
147  }
148 
149  std::vector<std::vector<double>> centers0, centers1;
150  if(json.isMember("centers"))
151  {
152  for(auto& s : json["centers"])
153  {
154  std::vector<double> center;
155  if(s.isArray())
156  for(auto& k : s)
157  center.push_back(k.asDouble());
158  else
159  center.push_back(s.asDouble());
160 
161  centers0.push_back(center);
162  }
163  }
164  else if(json.isMember("centers0") && json.isMember("centers1") && json.isMember("timesteps"))
165  {
166  for(auto& s : json["centers0"])
167  {
168  std::vector<double> center;
169  if(s.isArray())
170  for(auto& k : s)
171  center.push_back(k.asDouble());
172  else
173  center.push_back(s.asDouble());
174 
175  centers0.push_back(center);
176  }
177 
178  for(auto& s : json["centers1"])
179  {
180  std::vector<double> center;
181  if(s.isArray())
182  for(auto& k : s)
183  center.push_back(k.asDouble());
184  else
185  center.push_back(s.asDouble());
186 
187  centers1.push_back(center);
188  }
189  }
190  else
191  throw BuildException({"Either \"centers\" or \"timesteps\", \"centers0\" and \"centers1\" must be defined for umbrella."});
192 
193  if(ksprings[0].size() != centers0[0].size())
194  throw BuildException({"Need to define a spring for every center or a center for every spring!"});
195 
196  // If only one set of center/ksprings are specified. Fill it up for multi.
197  if(ismulti)
198  {
199  if(ksprings.size() == 1)
200  for(size_t i = 1; i < wcount; ++i)
201  ksprings.push_back(ksprings[0]);
202  else if(ksprings.size() != wcount)
203  throw std::invalid_argument(path + ": Multi-walker simulations requires that the number of \"ksprings\" match the number of walkers.");
204  if(centers0.size() == 1)
205  for(size_t i = 1; i < wcount; ++i)
206  centers0.push_back(centers0[0]);
207  else if(centers0.size() != wcount)
208  throw std::invalid_argument(path + ": Multi-walker simulations requires that the number of \"centers\"/\"centers0\" match the number of walkers.");
209  if(centers1.size() == 1)
210  for(size_t i = 1; i < wcount; ++i)
211  centers1.push_back(centers1[0]);
212  else if(centers1.size()) // centers1 is optional.
213  throw std::invalid_argument(path + ": Multi-walker simulations requires that the number of \"centers1\" match the number of walkers.");
214  }
215 
216  auto freq = json.get("frequency", 1).asInt();
217 
218  uint timesteps = 0;
219  if(json.isMember("timesteps"))
220  {
221  if(json["timesteps"].isArray())
222  timesteps = json["timesteps"][wid].asUInt();
223  else
224  timesteps = json["timesteps"].asUInt();
225  }
226 
227  std::string name = "umbrella.dat";
228  if(json["output_file"].isArray())
229  name = json["output_file"][wid].asString();
230  else if(ismulti)
231  throw std::invalid_argument(path + ": Multi-walker simulations require a separate output file for each.");
232  else
233  name = json["output_file"].asString();
234 
235  Umbrella* m = nullptr;
236  if(timesteps == 0)
237  m = new Umbrella(world, comm, ksprings[wid], centers0[wid], name, freq);
238  else
239  m = new Umbrella(world, comm, ksprings[wid], centers0[wid], centers1[wid], timesteps, name, freq);
240 
241  m->SetOutputFrequency(json.get("output_frequency",0).asInt());
242  m->SetAppend(json.get("append", false).asBool());
243 
244  return m;
245  }
246 }
bool HasErrors()
Check if errors have occurred.
Definition: Requirement.h:86
Collective variable manager.
Definition: CVManager.h:40
int GetIteration() const
Get the current iteration.
Definition: Snapshot.h:103
const Matrix3 & GetVirial() const
Get box virial.
Definition: Snapshot.h:133
std::vector< CollectiveVariable * > CVList
List of Collective Variables.
Definition: types.h:51
void SetOutputFrequency(int outfreq)
Set output frequency.
Definition: Umbrella.h:153
Class containing a snapshot of the current simulation in time.
Definition: Snapshot.h:43
Umbrella sampling method.
Definition: Umbrella.h:35
virtual void Parse(Value json, const std::string &path) override
Parse JSON value to generate Requirement(s).
Exception to be thrown when building the Driver fails.
std::vector< std::string > GetErrors()
Get list of error messages.
Definition: Requirement.h:92
std::vector< CollectiveVariable * > GetCVs(const std::vector< uint > &mask=std::vector< uint >()) const
Get list of CVs.
Definition: CVManager.h:80
Requirements on an object.
void SetAppend(bool append)
Set append mode.
Definition: Umbrella.h:162
const std::vector< Vector3 > & GetForces() const
Access the per-particle forces.
Definition: Snapshot.h:362
const Matrix3 & GetHMatrix() const
Get system H-matrix.
Definition: Snapshot.h:127
virtual void Validate(const Value &json, const std::string &path) override
Validate JSON value.